Skip to content

Commit

Permalink
vector db test
Browse files Browse the repository at this point in the history
  • Loading branch information
shultseva committed May 20, 2024
1 parent 6c8ed0f commit a1fc789
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,39 @@

public class ScoreMetrics {

private static final Histogram scoreHistogram = new Histogram(100, 3);
private String name;

public static void set(int score) {
private final Histogram scoreHistogram = new Histogram(100, 3);

public ScoreMetrics() {
}

public void set(int score) {
scoreHistogram.recordValue(score);
}

public static long getMin() {
public long getMin() {
return scoreHistogram.getMinValue();
}

public static long getMax() {
public long getMax() {
return scoreHistogram.getMaxValue();
}

public static double getMean() {
public double getMean() {
return scoreHistogram.getMean();
}

public static long getPercentLowerThen(int value) {
public double getPercentile(double value) {
return scoreHistogram.getValueAtPercentile(value);
}

public long getPercentLowerThen(int value) {
var lower = scoreHistogram.getCountBetweenValues(0, value);
return (lower * 100) / scoreHistogram.getTotalCount();
}

public void setName(String name) {
this.name = name;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,16 @@
import java.util.Queue;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Function;

import static java.util.concurrent.TimeUnit.MILLISECONDS;

public class VectorCollectionSearchDatasetTest extends HazelcastTest {

public String name;

public String datasetUrl;

public String workingDirectory;
Expand All @@ -58,35 +61,48 @@ public class VectorCollectionSearchDatasetTest extends HazelcastTest {

// inner test parameters

private static final String collectionName = "performance-collection";

private static final int PUT_BATCH_SIZE = 2_000;

private static final int MAX_PUT_ALL_IN_FLIGHT = 5;

private VectorCollection<Integer, Integer> collection;

private final AtomicInteger searchCounter = new AtomicInteger(0);

private final AtomicInteger putCounter = new AtomicInteger(0);

private TestDataset testDataset;

private final Queue<TestSearchResult> searchResults = new ConcurrentLinkedQueue<>();

private final ScoreMetrics scoreMetrics = new ScoreMetrics();

private long indexBuildTime = 0;

private final ReentrantLock lock = new ReentrantLock();

private final CountDownLatch setupDone = new CountDownLatch(1);

@Setup
public void setup() {
if (!lock.tryLock()) {
try {
setupDone.await();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
return;
}
// only one thread performs the setup
scoreMetrics.setName(name);
DatasetReader reader = DatasetReader.create(datasetUrl, workingDirectory, normalize);
logger.info("Use normalize: {}", normalize);
var size = Math.min(reader.getSize(), loadFirst);
int dimension = reader.getDimension();
assert dimension == reader.getTestDatasetDimension() : "dataset dimension does not correspond to query vector dimension";
testDataset = reader.getTestDataset();
numberOfSearchIterations = Math.min(numberOfSearchIterations, testDataset.size());

logger.info("Vector collection name: {}", name);
logger.info("Use normalize: {}", normalize);
collection = VectorCollection.getCollection(
targetInstance,
new VectorCollectionConfig(collectionName)
new VectorCollectionConfig(name)
.addVectorIndexConfig(
new VectorIndexConfig()
.setMetric(Metric.valueOf(metric))
Expand All @@ -96,23 +112,24 @@ public void setup() {
)
);

var start = System.currentTimeMillis();
var indexBuildTimeStart = System.currentTimeMillis();

Map<Integer, VectorDocument<Integer>> buffer = new HashMap<>();
int index;
Pipelining<Void> pipelining = new Pipelining<>(MAX_PUT_ALL_IN_FLIGHT);
logger.info("Start loading data...");

while ((index = putCounter.getAndIncrement()) < size) {
int index = 0;
while (index < size) {
buffer.put(index, VectorDocument.of(index, VectorValues.of(reader.getTrainVector(index))));
index++;
if (buffer.size() % PUT_BATCH_SIZE == 0) {
addToPipelineWithLogging(pipelining, collection.putAllAsync(buffer));
logger.info(
"Uploaded {} vectors from {}. Block size: {}. Total time (min): {}",
index,
size,
buffer.size(),
MILLISECONDS.toMinutes(System.currentTimeMillis() - start)
MILLISECONDS.toMinutes(System.currentTimeMillis() - indexBuildTimeStart)
);
buffer = new HashMap<>();
}
Expand All @@ -124,26 +141,26 @@ public void setup() {
}

logger.info("Start waiting pipeline results...");
var pipelineWaiting = withTimer(() -> {
try {
pipelining.results();
} catch (Exception e) {
throw new RuntimeException(e);
}
});
logger.info("Pipeline waiting finished in {} min", MILLISECONDS.toMinutes(pipelineWaiting));
try {
pipelining.results();
} catch (Exception e) {
throw new RuntimeException(e);
}

var cleanupTimer = withTimer(() -> collection.optimizeAsync().toCompletableFuture().join());
indexBuildTime = System.currentTimeMillis() - indexBuildTimeStart;

logger.info("Collection size: {}", size);
logger.info("Collection dimension: {}", reader.getDimension());
logger.info("Cleanup time (min): {}", MILLISECONDS.toMinutes(cleanupTimer));
logger.info("Index build time (min): {}", MILLISECONDS.toMinutes(System.currentTimeMillis() - start));
logger.info("Index build time (min): {}", MILLISECONDS.toMinutes(indexBuildTime));

setupDone.countDown();
}

@TimeStep()
public void search(ThreadState state) {
var iteration = searchCounter.getAndIncrement();
var iteration = state.getAndIncrementIteration();
if (iteration >= numberOfSearchIterations) {
testContext.stop();
return;
Expand All @@ -160,24 +177,36 @@ public void afterRun() {
int index = testSearchResult.index();
List<Integer> ids = new ArrayList<>();
VectorUtils.forEach(testSearchResult.results, r -> ids.add((Integer) r.getKey()));
ScoreMetrics.set((int) (testDataset.getPrecisionV1(ids, index, limit) * 100));
scoreMetrics.set((int) (testDataset.getPrecision(ids, index, limit) * 100));
});

writePureResultsToFile("precision.out");
logger.info("Number of search iteration: {}", searchCounter.get());
logger.info("Min score: {}", ScoreMetrics.getMin());
logger.info("Max score: {}", ScoreMetrics.getMax());
logger.info("Mean score: {}", ScoreMetrics.getMean());
logger.info("Percent of results lower then 98% precision: {}", ScoreMetrics.getPercentLowerThen(98));
writeAllSearchResultsToFile("precision_" + name + ".out");
appendStatisticsToFile("statistics.out");
logger.info("Results for {}", name);
logger.info("Min score: {}", scoreMetrics.getMin());
logger.info("Max score: {}", scoreMetrics.getMax());
logger.info("Mean score: {}", scoreMetrics.getMean());
logger.info("5pt: {}", scoreMetrics.getPercentile(5));
logger.info("10pt: {}", scoreMetrics.getPercentile(10));
logger.info("The percentage of results with precision lower than 98%: {}", scoreMetrics.getPercentLowerThen(98));
logger.info("The percentage of results with precision lower than 99%: {}", scoreMetrics.getPercentLowerThen(99));
}

/**
 * Thread state carrying a simple iteration counter for the search time step.
 *
 * <p>The counter is a plain {@code int} with no synchronization.
 * NOTE(review): presumably the framework creates one instance per worker thread,
 * which would make this safe — confirm against {@code BaseThreadState}.
 */
public static class ThreadState extends BaseThreadState {

    /** Number of iterations already handed out by this state object. */
    private int iteration = 0;

    /**
     * Returns the current iteration number and then advances the counter by one.
     *
     * @return the iteration value before the increment (starts at 0)
     */
    public int getAndIncrementIteration() {
        return iteration++;
    }
}

/**
 * One recorded search outcome, collected during the run and evaluated in {@code afterRun}.
 *
 * @param index        position of the query in the test dataset; {@code afterRun} passes it to
 *                     the dataset's precision calculation
 * @param searchVector presumably the query vector that was searched for — TODO confirm at the
 *                     place where the record is created
 * @param results      the search results; {@code afterRun} iterates them to collect the returned keys
 */
public record TestSearchResult(int index, float[] searchVector, SearchResults results) {
}

private void writePureResultsToFile(String fileName) {
private void writeAllSearchResultsToFile(String fileName) {
try {
Function<Float, Float> restore = VectorUtils.restoreRealMetric(Metric.valueOf(metric));
var fileWriter = new FileWriter(fileName);
Expand All @@ -203,6 +232,22 @@ private void writePureResultsToFile(String fileName) {
}
}

/**
 * Appends one summary line for this run to the given file.
 *
 * <p>The line has the form {@code name, indexBuildTime, meanScore}, joined with {@code ", "}.
 * The file is opened in append mode, so earlier lines are kept.
 *
 * @param fileName path of the statistics file to append to
 * @throws RuntimeException wrapping the {@link IOException} if the file cannot be opened
 *                          or the line cannot be written
 */
private void appendStatisticsToFile(String fileName) {
    // Build the line before opening the file so a failure here never leaves a dangling writer.
    List<String> values = List.of(
            name,
            String.valueOf(indexBuildTime),
            String.valueOf(scoreMetrics.getMean())
    );
    String line = String.join(", ", values);

    // try-with-resources closes both writers even when writing fails.
    try (FileWriter fileWriter = new FileWriter(fileName, true);
         PrintWriter printWriter = new PrintWriter(fileWriter)) {
        printWriter.println(line);
        // PrintWriter never throws on write; checkError() flushes and reports a swallowed failure.
        if (printWriter.checkError()) {
            throw new IOException("Failed to write statistics to " + fileName);
        }
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
}

void addToPipelineWithLogging(Pipelining<Void> pipelining, CompletionStage<Void> asyncInvocation) {
var now = System.currentTimeMillis();
try {
Expand All @@ -228,11 +273,9 @@ private long withTimer(Runnable runnable) {

private float getFirstCoordinate(VectorValues vectorValues) {
var v = (VectorValues.SingleVectorValues) vectorValues;
if(v == null || v.vector().length == 0) {
if (v == null || v.vector().length == 0) {
return 0;
}
return v.vector()[0];
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public static void forEach(SearchResults searchResults, Consumer<SearchResult> c

public static Function<Float, Float> restoreRealMetric(Metric metric) {
return switch (metric) {
case COSINE -> jMetric -> 2 * jMetric - 1;
case COSINE, DOT -> jMetric -> 2 * jMetric - 1;
default -> jMetric -> -1f;
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,10 @@ public int size() {
return searchVectors.length;
}

public float getPrecisionV1(List<Integer> actualVectorsIds, int index, int top) {
public float getPrecision(List<Integer> actualVectorsIds, int index, int top) {
var actualSet = new HashSet<>(actualVectorsIds);
var expectedSet = Arrays.stream(Arrays.copyOfRange(closestIds[index], 0, top)).boxed().collect(Collectors.toSet());
actualSet.retainAll(expectedSet);
return ((float) actualSet.size()) / top;
}

public float getPrecisionV2(float[] actualVectorsScore, int index) {
var expected = Arrays.copyOfRange(closestScores[index], 0, actualVectorsScore.length);
return distance(actualVectorsScore, expected);
}

private float distance(float[] array1, float[] array2) {
double sum = 0f;
for (int i = 0; i < array1.length; i++) {
sum += Math.pow((array1[i] - array2[i]), 2.0);
}
return (float) Math.sqrt(sum);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,64 +8,40 @@

public class TestDatasetDiffblueTest {
@Test
public void testGetPrecisionV1_score100() {
public void testGetPrecision_score100() {
var actual = (
new TestDataset(
new float[][]{new float[]{0f}},
new int[][]{new int[]{1, 2, 3, 4}},
new float[][]{new float[]{0f}}
)
).getPrecisionV1(List.of(1, 2), 0, 2);
).getPrecision(List.of(1, 2), 0, 2);
assertEquals(1, actual, 0.0f);
}

@Test
public void testGetPrecisionV1_score0() {
public void testGetPrecision_score0() {
assertEquals(0.0f,
(
new TestDataset(
new float[][]{new float[]{0f}},
new int[][]{new int[]{1, 2, 3, 4}, new int[]{1, 2, 1, 2}},
new float[][]{new float[]{0f}}
)
).getPrecisionV1(List.of(2), 0, 1),
).getPrecision(List.of(2), 0, 1),
0.0f);
}

@Test
public void testGetPrecisionV1_score50() {
public void testGetPrecision_score50() {
assertEquals(0.5f,
(
new TestDataset(
new float[][]{new float[]{0f}},
new int[][]{new int[]{1, 2, 3, 4}, new int[]{2, 5, 6}},
new float[][]{new float[]{0f}}
)
).getPrecisionV1(List.of(2, 6), 0, 2),
).getPrecision(List.of(2, 6), 0, 2),
0.1f);
}

@Test
public void testGetPrecisionV2_theSameVector() {
var actual = (
new TestDataset(
new float[][]{new float[]{0f}},
new int[][]{new int[]{1, 2, 3, 4}},
new float[][]{new float[]{10.0f, 0.5f, 10.0f, 0.5f}}
)
).getPrecisionV2(new float[]{10.0f, 0.5f, 10.0f, 0.5f}, 0);
assertEquals(0, actual, 0.0f);
}

@Test
public void testGetPrecisionV2_DiffrentVector() {
var actual = (
new TestDataset(
new float[][]{new float[]{0f}},
new int[][]{new int[]{1, 2, 3, 4}},
new float[][]{new float[]{10.0f, 0.5f, 10.0f, 0.5f}}
)
).getPrecisionV2(new float[]{8.0f, 0.4f, 9.0f, 0.5f}, 0);
assertEquals(Math.sqrt(5.01), actual, 0.1f);
}
}

0 comments on commit a1fc789

Please sign in to comment.