Skip to content

Commit

Permalink
Re-implement prompt builder, make it public, publish 3.0.0 core mod…
Browse files Browse the repository at this point in the history
…ule (#292)

* feat(PromptBuilder): re-implement `PromptBuilder`

Make `PromptBuilder` assert that the client inserts all mandatory keywords.

* Remove commented code & move `PromptKeyword`'s prop into its file

* Write tests for `PromptBuilder`

* Check the inserted keyword present in the prompt template + test it

* feat(publish): Publish the updated version of `PromptBuilder`

* refactor: Update check condition in `PromptBuilder.insert` method

* publish: Publish core module with version 2.0.7

This version includes all the changes of `PromptBuilder` on commit 1284ab4.

* refactor: remove commented and redundant code in `PromptBuilder`

* refactor: replace manual keyword variable creation with a prop

* refactor: remove `text` prop of `PromptKeyword` since it duplicates `name`

* feat: generate docs via AI for `PromptBuilder`

* refactor: add newline into the prompt in `PromptBuilder`

* feat: cover method info and polymorphism insertions with tests

* fix: apply klint

* fix: use `StringBuilder` for `fullText` assembling

* publish: publish `testspark-core` with version `3.0.0`

---------

Co-authored-by: Vladislav Artiukhov <[email protected]>
  • Loading branch information
Vladislav0Art and Vladislav Artiukhov authored Sep 27, 2024
1 parent 0ba555f commit 3f0a1b3
Show file tree
Hide file tree
Showing 6 changed files with 409 additions and 121 deletions.
2 changes: 1 addition & 1 deletion core/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ publishing {
create<MavenPublication>("maven") {
groupId = group as String
artifactId = "testspark-core"
version = "2.0.5"
version = "3.0.0"
from(components["java"])

artifact(tasks["sourcesJar"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,149 +2,168 @@ package org.jetbrains.research.testspark.core.generation.llm.prompt

import org.jetbrains.research.testspark.core.data.ClassType
import org.jetbrains.research.testspark.core.generation.llm.prompt.configuration.ClassRepresentation
import java.util.EnumMap

/**
* Builds prompts by populating a template with keyword values
* and validates that all mandatory keywords are provided.
*
* @property promptTemplate The template string for the prompt.
*/
class PromptBuilder(private val promptTemplate: String) {
private val insertedKeywordValues: EnumMap<PromptKeyword, String> = EnumMap(PromptKeyword::class.java)

// collect all the keywords present in the prompt template
private val templateKeywords: List<PromptKeyword> = buildList {
for (keyword in PromptKeyword.entries) {
if (promptTemplate.contains(keyword.variable)) {
add(keyword)
}
}
}

internal class PromptBuilder(private var prompt: String) {
private fun isPromptValid(
keyword: PromptKeyword,
prompt: String,
): Boolean {
val keywordText = keyword.text
val isMandatory = keyword.mandatory
/**
* Builds the prompt by populating the template with the inserted values
* and validating that all mandatory keywords were provided.
*
* @return The built prompt.
* @throws IllegalStateException if a mandatory keyword is not present in the template.
*/
fun build(): String {
var populatedPrompt = promptTemplate

// populate the template with the inserted values
for ((keyword, value) in insertedKeywordValues.entries) {
populatedPrompt = populatedPrompt.replace(keyword.variable, value, ignoreCase = false)
}

return (prompt.contains(keywordText) || !isMandatory)
// validate that all mandatory keywords were provided
for (keyword in templateKeywords) {
if (!insertedKeywordValues.contains(keyword) && keyword.mandatory) {
throw IllegalStateException("The prompt must contain ${keyword.name} keyword")
}
}

return populatedPrompt
}

fun insertLanguage(language: String) = apply {
if (isPromptValid(PromptKeyword.LANGUAGE, prompt)) {
val keyword = "\$${PromptKeyword.LANGUAGE.text}"
prompt = prompt.replace(keyword, language, ignoreCase = false)
} else {
throw IllegalStateException("The prompt must contain ${PromptKeyword.LANGUAGE.text}")
/**
* Inserts a keyword and its corresponding value into the prompt template.
* If the keyword is marked as mandatory and not present in the template, an IllegalArgumentException is thrown.
*
* @param keyword The keyword to be inserted.
* @param value The value corresponding to the keyword.
* @throws IllegalArgumentException if a mandatory keyword is not present in the template.
*/
private fun insert(keyword: PromptKeyword, value: String) {
if (!templateKeywords.contains(keyword) && keyword.mandatory) {
throw IllegalArgumentException("Prompt template does not contain mandatory ${keyword.name}")
}
insertedKeywordValues[keyword] = value
}

fun insertLanguage(language: String) = apply {
insert(PromptKeyword.LANGUAGE, language)
}

fun insertName(classDisplayName: String) = apply {
if (isPromptValid(PromptKeyword.NAME, prompt)) {
val keyword = "\$${PromptKeyword.NAME.text}"
prompt = prompt.replace(keyword, classDisplayName, ignoreCase = false)
} else {
throw IllegalStateException("The prompt must contain ${PromptKeyword.NAME.text}")
}
insert(PromptKeyword.NAME, classDisplayName)
}

fun insertTestingPlatform(testingPlatformName: String) = apply {
if (isPromptValid(PromptKeyword.TESTING_PLATFORM, prompt)) {
val keyword = "\$${PromptKeyword.TESTING_PLATFORM.text}"
prompt = prompt.replace(keyword, testingPlatformName, ignoreCase = false)
} else {
throw IllegalStateException("The prompt must contain ${PromptKeyword.TESTING_PLATFORM.text}")
}
insert(PromptKeyword.TESTING_PLATFORM, testingPlatformName)
}

fun insertMockingFramework(mockingFrameworkName: String) = apply {
if (isPromptValid(PromptKeyword.MOCKING_FRAMEWORK, prompt)) {
val keyword = "\$${PromptKeyword.MOCKING_FRAMEWORK.text}"
prompt = prompt.replace(keyword, mockingFrameworkName, ignoreCase = false)
} else {
throw IllegalStateException("The prompt must contain ${PromptKeyword.MOCKING_FRAMEWORK.text}")
}
insert(PromptKeyword.MOCKING_FRAMEWORK, mockingFrameworkName)
}

fun insertCodeUnderTest(classFullText: String, classesToTest: List<ClassRepresentation>) = apply {
if (isPromptValid(PromptKeyword.CODE, prompt)) {
val keyword = "\$${PromptKeyword.CODE.text}"
var fullText = "```\n${classFullText}\n```\n"

for (i in 2..classesToTest.size) {
val subClass = classesToTest[i - 2]
val superClass = classesToTest[i - 1]

fullText += "${subClass.qualifiedName} extends ${superClass.qualifiedName}. " +
"The source code of ${superClass.qualifiedName} is:\n```\n${superClass.fullText}\n" +
"```\n"
}
prompt = prompt.replace(keyword, fullText, ignoreCase = false)
} else {
throw IllegalStateException("The prompt must contain ${PromptKeyword.CODE.text}")
/**
* Inserts the code under test and its related superclass code into the prompt template.
*
* @param codeFullText The full text of the code under test.
* @param classesToTest The list of ClassRepresentation objects representing the classes involved in the code under test.
* @return The modified prompt builder.
*/
fun insertCodeUnderTest(codeFullText: String, classesToTest: List<ClassRepresentation>) = apply {
val fullText = StringBuilder("```\n${codeFullText}\n```\n")

for (i in 2..classesToTest.size) {
val subClass = classesToTest[i - 2]
val superClass = classesToTest[i - 1]

fullText.append("${subClass.qualifiedName} extends ${superClass.qualifiedName}. ")
.append("The source code of ${superClass.qualifiedName} is:\n```\n${superClass.fullText}\n")
.append("```\n")
}

insert(PromptKeyword.CODE, fullText.toString())
}

fun insertMethodsSignatures(interestingClasses: List<ClassRepresentation>) = apply {
val keyword = "\$${PromptKeyword.METHODS.text}"
val fullText = StringBuilder()

if (isPromptValid(PromptKeyword.METHODS, prompt)) {
var fullText = ""
if (interestingClasses.isNotEmpty()) {
fullText += "Here are some information about other methods and classes used by the class under test. Only use them for creating objects, not your own ideas.\n"
}
for (interestingClass in interestingClasses) {
if (interestingClass.qualifiedName.startsWith("java") || interestingClass.qualifiedName.startsWith("kotlin")) {
continue
}
if (interestingClasses.isNotEmpty()) {
fullText.append("Here are some information about other methods and classes used by the class under test. Only use them for creating objects, not your own ideas.\n")
}

fullText += "=== methods in ${interestingClass.qualifiedName}:\n"
for (interestingClass in interestingClasses) {
if (interestingClass.qualifiedName.startsWith("java") ||
interestingClass.qualifiedName.startsWith("kotlin")
) {
continue
}

for (method in interestingClass.allMethods) {
// Skip java methods
// TODO: checks for java methods should be done by a caller to make
// this class as abstract and language agnostic as possible.
if (method.containingClassQualifiedName.startsWith("java") ||
method.containingClassQualifiedName.startsWith("kotlin")
) {
continue
}
fullText.append("=== methods in ${interestingClass.qualifiedName}:\n")

fullText += " - ${method.signature}\n"
for (method in interestingClass.allMethods) {
// TODO: checks for java methods should be done by a caller to make
// this class as abstract and language agnostic as possible.
if (method.containingClassQualifiedName.startsWith("java") ||
method.containingClassQualifiedName.startsWith("kotlin")
) {
continue
}

fullText.append(" - ${method.signature}\n")
}
prompt = prompt.replace(keyword, fullText, ignoreCase = false)
} else {
throw IllegalStateException("The prompt must contain ${PromptKeyword.METHODS.text}")
}

insert(PromptKeyword.METHODS, fullText.toString())
}

fun insertPolymorphismRelations(
polymorphismRelations: Map<ClassRepresentation, List<ClassRepresentation>>,
) = apply {
val keyword = "\$${PromptKeyword.POLYMORPHISM.text}"
if (isPromptValid(PromptKeyword.POLYMORPHISM, prompt)) {
// If polymorphismRelations is not empty, we add an instruction to avoid mocking classes if an instantiation of a sub-class is applicable
var fullText = when {
polymorphismRelations.isNotEmpty() -> "Use the following polymorphic relationships of classes present in the project. Use them for instantiation when necessary. Do not mock classes if an instantiation of a sub-class is applicable"
else -> ""
}
polymorphismRelations.forEach { entry ->
for (currentSubClass in entry.value) {
val subClassTypeName = when (currentSubClass.classType) {
ClassType.INTERFACE -> "an interface implementing"
ClassType.ABSTRACT_CLASS -> "an abstract sub-class of"
ClassType.CLASS -> "a sub-class of"
ClassType.DATA_CLASS -> "a sub data class of"
ClassType.INLINE_VALUE_CLASS -> "a sub inline value class class of"
ClassType.OBJECT -> "a sub object of"
}
fullText += "${currentSubClass.qualifiedName} is $subClassTypeName ${entry.key.qualifiedName}.\n"
val fullText = StringBuilder()

if (polymorphismRelations.isNotEmpty()) {
fullText.append("Use the following polymorphic relationships of classes present in the project. Use them for instantiation when necessary. Do not mock classes if an instantiation of a sub-class is applicable.\n\n")
}

for (entry in polymorphismRelations) {
for (currentSubClass in entry.value) {
val subClassTypeName = when (currentSubClass.classType) {
ClassType.INTERFACE -> "an interface implementing"
ClassType.ABSTRACT_CLASS -> "an abstract sub-class of"
ClassType.CLASS -> "a sub-class of"
ClassType.DATA_CLASS -> "a sub data class of"
ClassType.INLINE_VALUE_CLASS -> "a sub inline value class class of"
ClassType.OBJECT -> "a sub object of"
}
fullText.append("${currentSubClass.qualifiedName} is $subClassTypeName ${entry.key.qualifiedName}.\n")
}
prompt = prompt.replace(keyword, fullText, ignoreCase = false)
} else {
throw IllegalStateException("The prompt must contain ${PromptKeyword.POLYMORPHISM.text}")
}

insert(PromptKeyword.POLYMORPHISM, fullText.toString())
}

fun insertTestSample(testSamplesCode: String) = apply {
val keyword = "\$${PromptKeyword.TEST_SAMPLE.text}"

if (isPromptValid(PromptKeyword.TEST_SAMPLE, prompt)) {
var fullText = testSamplesCode
if (fullText.isNotBlank()) {
fullText = "Use this test samples:\n$fullText\n"
}
prompt = prompt.replace(keyword, fullText, ignoreCase = false)
} else {
throw IllegalStateException("The prompt must contain ${PromptKeyword.TEST_SAMPLE.text}")
var fullText = testSamplesCode
if (fullText.isNotBlank()) {
fullText = "Use this test samples:\n$fullText\n"
}
}

fun build(): String = prompt
insert(PromptKeyword.TEST_SAMPLE, fullText)
}
}
Original file line number Diff line number Diff line change
@@ -1,26 +1,24 @@
package org.jetbrains.research.testspark.core.generation.llm.prompt

enum class PromptKeyword(val text: String, val description: String, val mandatory: Boolean) {
NAME("NAME", "The name of the code under test (Class name, method name, line number)", true),
CODE("CODE", "The code under test (Class, method, or line)", true),
LANGUAGE("LANGUAGE", "Programming language of the project under test (only Java supported at this point)", true),
enum class PromptKeyword(val description: String, val mandatory: Boolean) {
NAME("The name of the code under test (Class name, method name, line number)", true),
CODE("The code under test (Class, method, or line)", true),
LANGUAGE("Programming language of the project under test (only Java supported at this point)", true),
TESTING_PLATFORM(
"TESTING_PLATFORM",
"Testing platform used in the project (Only JUnit 4 is supported at this point)",
true,
),
MOCKING_FRAMEWORK(
"MOCKING_FRAMEWORK",
"Mock framework that can be used in generated test (Only Mockito is supported at this point)",
false,
),
METHODS("METHODS", "Signature of methods used in the code under tests", false),
POLYMORPHISM("POLYMORPHISM", "Polymorphism relations between classes involved in the code under test", false),
TEST_SAMPLE("TEST_SAMPLE", "Test samples for LLM for test generation", false),
METHODS("Signature of methods used in the code under tests", false),
POLYMORPHISM("Polymorphism relations between classes involved in the code under test", false),
TEST_SAMPLE("Test samples for LLM for test generation", false),
;

fun getOffsets(prompt: String): Pair<Int, Int>? {
val textToHighlight = "\$$text"
val textToHighlight = variable
if (!prompt.contains(textToHighlight)) {
return null
}
Expand All @@ -29,4 +27,13 @@ enum class PromptKeyword(val text: String, val description: String, val mandator
val endOffset = startOffset + textToHighlight.length
return Pair(startOffset, endOffset)
}

/**
* Returns a keyword's text (i.e., its name) with a `$` attached at the start.
*
* Inside a prompt template every keyword is used as `$KEYWORD_NAME`.
* Therefore, this property encapsulates the keyword's representation in a prompt.
*/
val variable: String
get() = "\$${this.name}"
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ object PromptParserHelper {
fun isPromptValid(prompt: String): Boolean {
PromptKeyword.entries.forEach {
if (it.mandatory) {
val text = "\$${it.text}"
if (!prompt.contains(text)) {
if (!prompt.contains(it.variable)) {
return false
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,8 @@ class LLMSettingsComponent(private val project: Project) : SettingsComponent {
private fun createButtonPanel(keyword: PromptKeyword, panel: JPanel): JPanel {
val buttonPanel = JPanel(FlowLayout(FlowLayout.LEFT))
val editorTextField = panel.getComponent(1) as EditorTextField
val button = JButton("\$${keyword.text}")
button.setForeground(JBColor.ORANGE)
val button = JButton(keyword.variable)
button.foreground = JBColor.ORANGE
button.font = Font("Monochrome", Font.BOLD, 12)

// add actionListener for button
Expand All @@ -340,7 +340,7 @@ class LLMSettingsComponent(private val project: Project) : SettingsComponent {
val offset = e.caretModel.offset
val document = editorTextField.document
WriteCommandAction.runWriteCommandAction(e.project) {
document.insertString(offset, "\$${keyword.text}")
document.insertString(offset, keyword.variable)
}
}
}
Expand Down
Loading

0 comments on commit 3f0a1b3

Please sign in to comment.