Skip to content

Commit

Permalink
Add validation and rename parameter in TokenCountBatchingStrategy
Browse files Browse the repository at this point in the history
This commit rename the thresholdFactor parameter to reservePercentage to better
reflect its purpose in the TokenCountBatchingStrategy class. It also adds
validation for input parameters including maxInputTokenCount > 0 and
reservePercentage between 0 and 1 for safer initialization.

The change ensures proper parameter validation on TokenCountBatchingStrategy
creation to prevent potential runtime errors from invalid inputs.
  • Loading branch information
1993heqiang authored and Mark Pollack committed Nov 6, 2024
1 parent 83f7164 commit 0825d52
Showing 1 changed file with 11 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,13 @@ public TokenCountBatchingStrategy() {

/**
* @param encodingType {@link EncodingType}
* @param thresholdFactor the threshold factor to use on top of the max input token
* count
* @param reservePercentage the percentage of tokens to reserve from the max input
* token count to create a buffer.
* @param maxInputTokenCount upper limit for input tokens
*/
public TokenCountBatchingStrategy(EncodingType encodingType, int maxInputTokenCount, double thresholdFactor) {
this(encodingType, maxInputTokenCount, thresholdFactor, Document.DEFAULT_CONTENT_FORMATTER, MetadataMode.NONE);
public TokenCountBatchingStrategy(EncodingType encodingType, int maxInputTokenCount, double reservePercentage) {
this(encodingType, maxInputTokenCount, reservePercentage, Document.DEFAULT_CONTENT_FORMATTER,
MetadataMode.NONE);
}

/**
Expand All @@ -99,6 +100,8 @@ public TokenCountBatchingStrategy(EncodingType encodingType, int maxInputTokenCo
public TokenCountBatchingStrategy(EncodingType encodingType, int maxInputTokenCount, double reservePercentage,
ContentFormatter contentFormatter, MetadataMode metadataMode) {
Assert.notNull(encodingType, "EncodingType must not be null");
Assert.isTrue(maxInputTokenCount > 0, "MaxInputTokenCount must be greater than 0");
Assert.isTrue(reservePercentage >= 0 && reservePercentage < 1, "ReservePercentage must be in range [0, 1)");
Assert.notNull(contentFormatter, "ContentFormatter must not be null");
Assert.notNull(metadataMode, "MetadataMode must not be null");
this.tokenCountEstimator = new JTokkitTokenCountEstimator(encodingType);
Expand All @@ -120,6 +123,10 @@ public TokenCountBatchingStrategy(EncodingType encodingType, int maxInputTokenCo
public TokenCountBatchingStrategy(TokenCountEstimator tokenCountEstimator, int maxInputTokenCount,
double reservePercentage, ContentFormatter contentFormatter, MetadataMode metadataMode) {
Assert.notNull(tokenCountEstimator, "TokenCountEstimator must not be null");
Assert.isTrue(maxInputTokenCount > 0, "MaxInputTokenCount must be greater than 0");
Assert.isTrue(reservePercentage >= 0 && reservePercentage < 1, "ReservePercentage must be in range [0, 1)");
Assert.notNull(contentFormatter, "ContentFormatter must not be null");
Assert.notNull(metadataMode, "MetadataMode must not be null");
this.tokenCountEstimator = tokenCountEstimator;
this.maxInputTokenCount = (int) Math.round(maxInputTokenCount * (1 - reservePercentage));
this.contentFormater = contentFormatter;
Expand Down

0 comments on commit 0825d52

Please sign in to comment.