Skip to content

Commit

Permalink
Added support for halfvec type
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed May 19, 2024
1 parent 45d8b38 commit 54b8f98
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 0 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## 0.1.5 (unreleased)

- Added support for `halfvec` type

## 0.1.4 (2023-12-08)

- Added `List` constructor
Expand Down
109 changes: 109 additions & 0 deletions src/main/java/com/pgvector/PGhalfvec.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package com.pgvector;

import java.io.Serializable;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import org.postgresql.PGConnection;
import org.postgresql.util.PGobject;

/**
* PGhalfvec class
*/
public class PGhalfvec extends PGobject implements Serializable, Cloneable {
private float[] vec;

/**
* Constructor
*/
public PGhalfvec() {
type = "halfvec";
}

/**
* Constructor
*
* @param v float array
*/
public PGhalfvec(float[] v) {
this();
vec = v;
}

/**
* Constructor
*
* @param <T> number
* @param v list of numbers
*/
public <T extends Number> PGhalfvec(List<T> v) {
this();
if (Objects.isNull(v)) {
vec = null;
} else {
vec = new float[v.size()];
int i = 0;
for (T f : v) {
vec[i++] = f.floatValue();
}
}
}

/**
* Constructor
*
* @param s text representation of a half vector
* @throws SQLException exception
*/
public PGhalfvec(String s) throws SQLException {
this();
setValue(s);
}

/**
* Sets the value from a text representation of a half vector
*/
public void setValue(String s) throws SQLException {
if (s == null) {
vec = null;
} else {
String[] sp = s.substring(1, s.length() - 1).split(",");
vec = new float[sp.length];
for (int i = 0; i < sp.length; i++) {
vec[i] = Float.parseFloat(sp[i]);
}
}
}

/**
* Returns the text representation of a half vector
*/
public String getValue() {
if (vec == null) {
return null;
} else {
return Arrays.toString(vec).replace(" ", "");
}
}

/**
* Returns an array
*
* @return an array
*/
public float[] toArray() {
return vec;
}

/**
* Registers the halfvec type
*
* @param conn connection
* @throws SQLException exception
*/
public static void addHalfvecType(Connection conn) throws SQLException {
conn.unwrap(PGConnection.class).addDataType("halfvec", PGhalfvec.class);
}
}
36 changes: 36 additions & 0 deletions src/test/java/com/pgvector/JDBCJavaTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,40 @@ void example(boolean readBinary) throws SQLException {

conn.close();
}

@Test
void testHalfvec() throws SQLException {
Connection conn = DriverManager.getConnection("jdbc:postgresql://localhost:5432/pgvector_java_test");

Statement setupStmt = conn.createStatement();
setupStmt.executeUpdate("CREATE EXTENSION IF NOT EXISTS vector");
setupStmt.executeUpdate("DROP TABLE IF EXISTS jdbc_items");

PGhalfvec.addHalfvecType(conn);

Statement createStmt = conn.createStatement();
createStmt.executeUpdate("CREATE TABLE jdbc_items (id bigserial PRIMARY KEY, embedding halfvec(3))");

PreparedStatement insertStmt = conn.prepareStatement("INSERT INTO jdbc_items (embedding) VALUES (?), (?), (?), (?)");
insertStmt.setObject(1, new PGhalfvec(new float[] {1, 1, 1}));
insertStmt.setObject(2, new PGhalfvec(new float[] {2, 2, 2}));
insertStmt.setObject(3, new PGhalfvec(new float[] {1, 1, 2}));
insertStmt.setObject(4, null);
insertStmt.executeUpdate();

PreparedStatement neighborStmt = conn.prepareStatement("SELECT * FROM jdbc_items ORDER BY embedding <-> ? LIMIT 5");
neighborStmt.setObject(1, new PGhalfvec(new float[] {1, 1, 1}));
ResultSet rs = neighborStmt.executeQuery();
List<Long> ids = new ArrayList<>();
List<PGhalfvec> embeddings = new ArrayList<>();
while (rs.next()) {
ids.add(rs.getLong("id"));
embeddings.add((PGhalfvec) rs.getObject("embedding"));
}
assertArrayEquals(new Long[] {1L, 3L, 2L, 4L}, ids.toArray());
assertArrayEquals(new float[] {1, 1, 1}, embeddings.get(0).toArray());
assertArrayEquals(new float[] {1, 1, 2}, embeddings.get(1).toArray());
assertArrayEquals(new float[] {2, 2, 2}, embeddings.get(2).toArray());
assertNull(embeddings.get(3));
}
}
43 changes: 43 additions & 0 deletions src/test/java/com/pgvector/PGhalfvecTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package com.pgvector;

import java.sql.SQLException;
import java.util.Arrays;
import com.pgvector.PGhalfvec;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;

public class PGhalfvecTest {
@Test
void testArrayConstructor() {
PGhalfvec vec = new PGhalfvec(new float[] {1, 2, 3});
assertArrayEquals(new float[] {1, 2, 3}, vec.toArray());
}

@Test
void testStringConstructor() throws SQLException {
PGhalfvec vec = new PGhalfvec("[1,2,3]");
assertArrayEquals(new float[] {1, 2, 3}, vec.toArray());
}

@Test
void testFloatListConstructor() {
Float[] a = new Float[] {Float.valueOf(1), Float.valueOf(2), Float.valueOf(3)};
PGhalfvec vec = new PGhalfvec(Arrays.asList(a));
assertArrayEquals(new float[] {1, 2, 3}, vec.toArray());
}

@Test
void testDoubleListConstructor() {
Double[] a = new Double[] {Double.valueOf(1), Double.valueOf(2), Double.valueOf(3)};
PGhalfvec vec = new PGhalfvec(Arrays.asList(a));
assertArrayEquals(new float[] {1, 2, 3}, vec.toArray());
}

@Test
void testGetValue() {
PGhalfvec vec = new PGhalfvec(new float[] {1, 2, 3});
assertEquals("[1.0,2.0,3.0]", vec.getValue());
}
}

0 comments on commit 54b8f98

Please sign in to comment.