Skip to content

Commit

Permalink
Added support for bit type
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed May 23, 2024
1 parent 5e18949 commit b7bb910
Show file tree
Hide file tree
Showing 4 changed files with 202 additions and 1 deletion.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
## 0.1.5 (unreleased)

- Added support for `halfvec` and `sparsevec` types
- Added support for `halfvec`, `bit`, and `sparsevec` types

## 0.1.4 (2023-12-08)

Expand Down
124 changes: 124 additions & 0 deletions src/main/java/com/pgvector/PGbit.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package com.pgvector;

import java.io.Serializable;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import org.postgresql.PGConnection;
import org.postgresql.util.ByteConverter;
import org.postgresql.util.PGBinaryObject;
import org.postgresql.util.PGobject;

/**
 * PostgreSQL {@code bit} value: a bit string of a given length.
 *
 * <p>Bits are stored packed, eight per byte, with the first bit of the string in the
 * most significant position of the first byte. When the value is built from a boolean
 * array or a text string, the unused low-order bits of the last byte are zero.
 */
public class PGbit extends PGobject implements PGBinaryObject, Serializable, Cloneable {
    /** Number of bytes used by the bit-count header in the binary representation. */
    private static final int HEADER_BYTES = 4;

    /** Number of bits in the string; meaningful only while {@code data} is non-null. */
    private int length;

    /** Packed bits, or {@code null} when this object holds SQL NULL. */
    private byte[] data;

    /**
     * Constructor
     */
    public PGbit() {
        type = "bit";
    }

    /**
     * Constructor
     *
     * @param v boolean array, one element per bit ({@code true} is 1)
     * @throws NullPointerException if {@code v} is null
     */
    public PGbit(boolean[] v) {
        this();
        Objects.requireNonNull(v, "v");
        length = v.length;
        data = new byte[byteCount(length)];
        for (int i = 0; i < length; i++) {
            if (v[i]) {
                data[i / 8] |= 0x80 >>> (i % 8);
            }
        }
    }

    /**
     * Constructor
     *
     * @param s text representation of a bit string
     * @throws SQLException exception
     */
    public PGbit(String s) throws SQLException {
        this();
        setValue(s);
    }

    /**
     * Sets the value from a text representation of a bit string
     *
     * @param s string made only of {@code '0'} and {@code '1'} characters, or null for SQL NULL
     * @throws SQLException if {@code s} contains any other character; the current value
     *         is left untouched in that case
     */
    public void setValue(String s) throws SQLException {
        if (s == null) {
            length = 0;
            data = null;
            return;
        }

        int bits = s.length();
        // Build into a local so a malformed string cannot leave this object half-updated.
        byte[] packed = new byte[byteCount(bits)];
        for (int i = 0; i < bits; i++) {
            char c = s.charAt(i);
            if (c == '1') {
                packed[i / 8] |= 0x80 >>> (i % 8);
            } else if (c != '0') {
                throw new SQLException(
                    "Invalid bit string: unexpected character '" + c + "' at index " + i);
            }
        }
        length = bits;
        data = packed;
    }

    /**
     * Returns the text representation of a bit string
     *
     * @return a string of {@code '0'} and {@code '1'} characters, or null if no value is set
     */
    public String getValue() {
        if (data == null) {
            return null;
        }

        StringBuilder sb = new StringBuilder(length);
        for (int i = 0; i < length; i++) {
            boolean set = ((data[i / 8] >> (7 - (i % 8))) & 1) == 1;
            sb.append(set ? '1' : '0');
        }
        return sb.toString();
    }

    /**
     * Returns the number of bytes for the binary representation
     *
     * @return 4 header bytes plus the packed bits, or 0 if no value is set
     */
    public int lengthInBytes() {
        return data == null ? 0 : HEADER_BYTES + data.length;
    }

    /**
     * Sets the value from a binary representation of a bit string
     *
     * @param value buffer holding a 4-byte bit count followed by the packed bits
     * @param offset index in {@code value} where the bit count starts
     * @throws SQLException if the bit count is negative or the buffer is too short for it
     */
    public void setByteValue(byte[] value, int offset) throws SQLException {
        if (offset < 0 || value.length - offset < HEADER_BYTES) {
            throw new SQLException("Invalid binary bit value: missing length header");
        }

        int bits = ByteConverter.int4(value, offset);
        if (bits < 0) {
            throw new SQLException("Invalid binary bit value: negative length " + bits);
        }

        int bytes = byteCount(bits);
        int start = offset + HEADER_BYTES;
        if (value.length - start < bytes) {
            throw new SQLException(
                "Invalid binary bit value: expected " + bytes + " data bytes but found "
                    + (value.length - start));
        }

        length = bits;
        data = Arrays.copyOfRange(value, start, start + bytes);
    }

    /**
     * Writes the binary representation of a bit string
     *
     * <p>Writes nothing when no value is set.
     *
     * @param bytes destination buffer
     * @param offset index in {@code bytes} where writing starts
     */
    public void toBytes(byte[] bytes, int offset) {
        if (data == null) {
            return;
        }

        ByteConverter.int4(bytes, offset, length);
        System.arraycopy(data, 0, bytes, offset + HEADER_BYTES, data.length);
    }

    /**
     * Registers the bit type
     *
     * @param conn connection
     * @throws SQLException exception
     */
    public static void addBitType(Connection conn) throws SQLException {
        conn.unwrap(PGConnection.class).addDataType("bit", PGbit.class);
    }

    /**
     * Returns the number of bytes needed to hold {@code bits} bits.
     *
     * <p>Computed in {@code long} so that a bit count close to {@code Integer.MAX_VALUE}
     * does not overflow into a negative size.
     */
    private static int byteCount(int bits) {
        return (int) (((long) bits + 7) / 8);
    }
}
54 changes: 54 additions & 0 deletions src/test/java/com/pgvector/JDBCJavaTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;

public class JDBCJavaTest {
Expand Down Expand Up @@ -117,6 +118,59 @@ void halfvecExample(boolean readBinary) throws SQLException {
assertNull(embeddings.get(3));
}

/** Runs the bit example without switching the connection to binary reads. */
@Test
void testBitReadText() throws SQLException {
    final boolean readBinary = false;
    bitExample(readBinary);
}

/** Runs the bit example with the connection switched to binary reads. */
@Test
void testBitReadBinary() throws SQLException {
    final boolean readBinary = true;
    bitExample(readBinary);
}

/**
 * Round-trips {@code bit(9)} values through a scratch {@code jdbc_items} table: inserts
 * three bit strings and a NULL, reads them back ordered by the {@code <~>} operator
 * against a query value, checks ids and text values, then creates an ivfflat index.
 *
 * <p>Every JDBC resource is opened in try-with-resources so a failed assertion or SQL
 * error cannot leak the connection, statements, or result set.
 *
 * @param readBinary when true, sets the prepare threshold to -1 before running
 * @throws SQLException if any database call fails
 */
void bitExample(boolean readBinary) throws SQLException {
    try (Connection conn = DriverManager.getConnection("jdbc:postgresql://localhost:5432/pgvector_java_test")) {
        if (readBinary) {
            // NOTE(review): -1 presumably forces binary transfer in pgjdbc; confirm against driver docs.
            conn.unwrap(PGConnection.class).setPrepareThreshold(-1);
        }

        try (Statement setupStmt = conn.createStatement()) {
            setupStmt.executeUpdate("CREATE EXTENSION IF NOT EXISTS vector");
            setupStmt.executeUpdate("DROP TABLE IF EXISTS jdbc_items");
        }

        PGbit.addBitType(conn);

        try (Statement createStmt = conn.createStatement()) {
            createStmt.executeUpdate("CREATE TABLE jdbc_items (id bigserial PRIMARY KEY, embedding bit(9))");
        }

        try (PreparedStatement insertStmt = conn.prepareStatement("INSERT INTO jdbc_items (embedding) VALUES (?), (?), (?), (?)")) {
            insertStmt.setObject(1, new PGbit(new boolean[] {false, false, false, false, false, false, false, false, false}));
            insertStmt.setObject(2, new PGbit(new boolean[] {false, true, false, true, false, false, false, false, true}));
            insertStmt.setObject(3, new PGbit(new boolean[] {false, true, true, true, false, false, false, false, true}));
            insertStmt.setObject(4, null);
            insertStmt.executeUpdate();
        }

        List<Long> ids = new ArrayList<>();
        List<PGbit> embeddings = new ArrayList<>();
        try (PreparedStatement neighborStmt = conn.prepareStatement("SELECT * FROM jdbc_items ORDER BY embedding <~> ? LIMIT 5")) {
            neighborStmt.setObject(1, new PGbit(new boolean[] {false, true, false, true, false, false, false, false, true}));
            try (ResultSet rs = neighborStmt.executeQuery()) {
                while (rs.next()) {
                    ids.add(rs.getLong("id"));
                    embeddings.add((PGbit) rs.getObject("embedding"));
                }
            }
        }
        assertArrayEquals(new Long[] {2L, 3L, 1L, 4L}, ids.toArray());
        assertEquals("010100001", embeddings.get(0).getValue());
        assertEquals("011100001", embeddings.get(1).getValue());
        assertEquals("000000000", embeddings.get(2).getValue());
        assertNull(embeddings.get(3));

        try (Statement indexStmt = conn.createStatement()) {
            indexStmt.executeUpdate("CREATE INDEX ON jdbc_items USING ivfflat (embedding bit_hamming_ops) WITH (lists = 100)");
        }
    }
}

@Test
void testSparsevecReadText() throws SQLException {
sparsevecExample(false);
Expand Down
23 changes: 23 additions & 0 deletions src/test/java/com/pgvector/PGbitTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.pgvector;

import java.sql.SQLException;
import java.util.Arrays;
import com.pgvector.PGbit;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;

/**
 * Unit tests for building a {@link PGbit} and reading back its text form.
 */
public class PGbitTest {
    /** A boolean array is rendered as one '1' or '0' per element, in order. */
    @Test
    void testArrayConstructor() {
        boolean[] input = {true, false, true};
        PGbit bits = new PGbit(input);
        String text = bits.getValue();
        assertEquals("101", text);
    }

    /** A text bit string is returned unchanged by getValue. */
    @Test
    void testStringConstructor() throws SQLException {
        PGbit bits = new PGbit("101");
        String text = bits.getValue();
        assertEquals("101", text);
    }
}

0 comments on commit b7bb910

Please sign in to comment.