From 54b8f98c0a95d610ba0ae3071cd7e9bb06d9815a Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Sat, 18 May 2024 22:05:57 -0400 Subject: [PATCH] Added support for halfvec type --- CHANGELOG.md | 4 + src/main/java/com/pgvector/PGhalfvec.java | 109 ++++++++++++++++++ src/test/java/com/pgvector/JDBCJavaTest.java | 36 ++++++ src/test/java/com/pgvector/PGhalfvecTest.java | 43 +++++++ 4 files changed, 192 insertions(+) create mode 100644 src/main/java/com/pgvector/PGhalfvec.java create mode 100644 src/test/java/com/pgvector/PGhalfvecTest.java diff --git a/CHANGELOG.md b/CHANGELOG.md index b840f48..ed5a202 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 0.1.5 (unreleased) + +- Added support for `halfvec` type + ## 0.1.4 (2023-12-08) - Added `List` constructor diff --git a/src/main/java/com/pgvector/PGhalfvec.java b/src/main/java/com/pgvector/PGhalfvec.java new file mode 100644 index 0000000..23aed13 --- /dev/null +++ b/src/main/java/com/pgvector/PGhalfvec.java @@ -0,0 +1,109 @@ +package com.pgvector; + +import java.io.Serializable; +import java.sql.Connection; +import java.sql.SQLException; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import org.postgresql.PGConnection; +import org.postgresql.util.PGobject; + +/** + * PGhalfvec class + */ +public class PGhalfvec extends PGobject implements Serializable, Cloneable { + private float[] vec; + + /** + * Constructor + */ + public PGhalfvec() { + type = "halfvec"; + } + + /** + * Constructor + * + * @param v float array + */ + public PGhalfvec(float[] v) { + this(); + vec = v; + } + + /** + * Constructor + * + * @param number + * @param v list of numbers + */ + public PGhalfvec(List v) { + this(); + if (Objects.isNull(v)) { + vec = null; + } else { + vec = new float[v.size()]; + int i = 0; + for (T f : v) { + vec[i++] = f.floatValue(); + } + } + } + + /** + * Constructor + * + * @param s text representation of a half vector + * @throws SQLException exception + */ + public PGhalfvec(String s) throws SQLException { + this(); + setValue(s); + } + + /** + * Sets the value from a text representation of a half vector + */ + public void setValue(String s) throws SQLException { + if (s == null) { + vec = null; + } else { + String[] sp = s.substring(1, s.length() - 1).split(","); + vec = new float[sp.length]; + for (int i = 0; i < sp.length; i++) { + vec[i] = Float.parseFloat(sp[i]); + } + } + } + + /** + * Returns the text representation of a half vector + */ + public String getValue() { + if (vec == null) { + return null; + } else { + return Arrays.toString(vec).replace(" ", ""); + } + } + + /** + * Returns an array + * + * @return an array + */ + public float[] toArray() { + return vec; + } + + /** + * Registers the halfvec type + * + * @param conn connection + * @throws SQLException exception + */ + public static void addHalfvecType(Connection conn) throws SQLException { + conn.unwrap(PGConnection.class).addDataType("halfvec", PGhalfvec.class); + } +} diff --git a/src/test/java/com/pgvector/JDBCJavaTest.java b/src/test/java/com/pgvector/JDBCJavaTest.java index b9b0948..018713d 100644 --- a/src/test/java/com/pgvector/JDBCJavaTest.java +++ b/src/test/java/com/pgvector/JDBCJavaTest.java @@ -68,4 +68,40 @@ void example(boolean readBinary) throws SQLException { conn.close(); } + + @Test + void testHalfvec() throws SQLException { + Connection conn = DriverManager.getConnection("jdbc:postgresql://localhost:5432/pgvector_java_test"); + + Statement setupStmt = conn.createStatement(); + setupStmt.executeUpdate("CREATE EXTENSION IF NOT EXISTS vector"); + setupStmt.executeUpdate("DROP TABLE IF EXISTS jdbc_items"); + + PGhalfvec.addHalfvecType(conn); + + Statement createStmt = conn.createStatement(); + createStmt.executeUpdate("CREATE TABLE jdbc_items (id bigserial PRIMARY KEY, embedding halfvec(3))"); + + PreparedStatement insertStmt = conn.prepareStatement("INSERT INTO jdbc_items (embedding) VALUES (?), (?), (?), (?)"); + insertStmt.setObject(1, new PGhalfvec(new float[] {1, 1, 1})); + insertStmt.setObject(2, new PGhalfvec(new float[] {2, 2, 2})); + insertStmt.setObject(3, new PGhalfvec(new float[] {1, 1, 2})); + insertStmt.setObject(4, null); + insertStmt.executeUpdate(); + + PreparedStatement neighborStmt = conn.prepareStatement("SELECT * FROM jdbc_items ORDER BY embedding <-> ? LIMIT 5"); + neighborStmt.setObject(1, new PGhalfvec(new float[] {1, 1, 1})); + ResultSet rs = neighborStmt.executeQuery(); + List ids = new ArrayList<>(); + List embeddings = new ArrayList<>(); + while (rs.next()) { + ids.add(rs.getLong("id")); + embeddings.add((PGhalfvec) rs.getObject("embedding")); + } + assertArrayEquals(new Long[] {1L, 3L, 2L, 4L}, ids.toArray()); + assertArrayEquals(new float[] {1, 1, 1}, embeddings.get(0).toArray()); + assertArrayEquals(new float[] {1, 1, 2}, embeddings.get(1).toArray()); + assertArrayEquals(new float[] {2, 2, 2}, embeddings.get(2).toArray()); + assertNull(embeddings.get(3)); + } } diff --git a/src/test/java/com/pgvector/PGhalfvecTest.java b/src/test/java/com/pgvector/PGhalfvecTest.java new file mode 100644 index 0000000..823967f --- /dev/null +++ b/src/test/java/com/pgvector/PGhalfvecTest.java @@ -0,0 +1,43 @@ +package com.pgvector; + +import java.sql.SQLException; +import java.util.Arrays; +import com.pgvector.PGhalfvec; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class PGhalfvecTest { + @Test + void testArrayConstructor() { + PGhalfvec vec = new PGhalfvec(new float[] {1, 2, 3}); + assertArrayEquals(new float[] {1, 2, 3}, vec.toArray()); + } + + @Test + void testStringConstructor() throws SQLException { + PGhalfvec vec = new PGhalfvec("[1,2,3]"); + assertArrayEquals(new float[] {1, 2, 3}, vec.toArray()); + } + + @Test + void testFloatListConstructor() { + Float[] a = new Float[] {Float.valueOf(1), Float.valueOf(2), Float.valueOf(3)}; + PGhalfvec vec = new PGhalfvec(Arrays.asList(a)); + assertArrayEquals(new float[] {1, 2, 3}, vec.toArray()); + } + + @Test + void testDoubleListConstructor() { + Double[] a = new Double[] {Double.valueOf(1), Double.valueOf(2), Double.valueOf(3)}; + PGhalfvec vec = new PGhalfvec(Arrays.asList(a)); + assertArrayEquals(new float[] {1, 2, 3}, vec.toArray()); + } + + @Test + void testGetValue() { + PGhalfvec vec = new PGhalfvec(new float[] {1, 2, 3}); + assertEquals("[1.0,2.0,3.0]", vec.getValue()); + } +}