diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/data/CqlVector.java b/core/src/main/java/com/datastax/oss/driver/api/core/data/CqlVector.java index 2889ea5eb24..4d388e3062c 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/data/CqlVector.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/data/CqlVector.java @@ -218,7 +218,7 @@ private void readObject(ObjectInputStream stream) throws IOException, ClassNotFo stream.defaultReadObject(); int size = stream.readInt(); - list = new ArrayList<>(size); + list = new ArrayList<>(); for (int i = 0; i < size; i++) { list.add((T) stream.readObject()); } diff --git a/core/src/test/java/com/datastax/oss/driver/api/core/data/CqlVectorTest.java b/core/src/test/java/com/datastax/oss/driver/api/core/data/CqlVectorTest.java index 75dfbc26e42..ff28edf85ba 100644 --- a/core/src/test/java/com/datastax/oss/driver/api/core/data/CqlVectorTest.java +++ b/core/src/test/java/com/datastax/oss/driver/api/core/data/CqlVectorTest.java @@ -17,16 +17,22 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.fail; import com.datastax.oss.driver.api.core.type.codec.TypeCodecs; import com.datastax.oss.driver.internal.SerializationHelper; import com.datastax.oss.driver.shaded.guava.common.collect.Iterators; +import java.io.ByteArrayInputStream; +import java.io.ObjectInputStream; +import java.io.ObjectStreamException; import java.util.AbstractList; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.stream.Collectors; +import org.apache.commons.codec.DecoderException; +import org.apache.commons.codec.binary.Hex; import org.assertj.core.util.Lists; import org.junit.Test; @@ -231,4 +237,20 @@ public int size() { CqlVector deserialized = SerializationHelper.serializeAndDeserialize(initial); assertThat(deserialized).isEqualTo(initial); } + + @Test + public void should_not_use_preallocate_serialized_size() throws DecoderException { + // serialized CqlVector(1.0f, 2.5f, 3.0f) with size field adjusted to Integer.MAX_VALUE + byte[] suspiciousBytes = + Hex.decodeHex( + "aced000573720042636f6d2e64617461737461782e6f73732e6472697665722e6170692e636f72652e646174612e43716c566563746f722453657269616c697a6174696f6e50726f78790000000000000001030000787077047fffffff7372000f6a6176612e6c616e672e466c6f6174daedc9a2db3cf0ec02000146000576616c7565787200106a6176612e6c616e672e4e756d62657286ac951d0b94e08b02000078703f8000007371007e0002402000007371007e00024040000078" + .toCharArray()); + try { + new ObjectInputStream(new ByteArrayInputStream(suspiciousBytes)).readObject(); + fail("Should not be able to deserialize bytes with incorrect size field"); + } catch (Exception e) { + // check we fail to deserialize, rather than OOM + assertThat(e).isInstanceOf(ObjectStreamException.class); + } + } }