Skip to content

Commit 0b94a20

Browse files
committed
Changed HalfVector to use std::float16_t when available
1 parent 30d3478 commit 0b94a20

File tree

5 files changed

+37
-22
lines changed

5 files changed

+37
-22
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
## 0.3.0 (unreleased)
22

33
- Added support for libpqxx 8
4+
- Changed `HalfVector` to use `std::float16_t` when available
45
- Dropped support for libpqxx 7
56
- Dropped support for C++17
67

include/pgvector/halfvec.hpp

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,31 +12,42 @@
1212
#include <utility>
1313
#include <vector>
1414

15+
#if __STDCPP_FLOAT16_T__
16+
#include <stdfloat>
17+
#endif
18+
1519
namespace pgvector {
20+
21+
#if __STDCPP_FLOAT16_T__
22+
using HalfType = std::float16_t;
23+
#else
24+
using HalfType = float;
25+
#endif
26+
1627
/// A half vector.
1728
class HalfVector {
1829
public:
19-
/// Creates a half vector from a `std::vector<float>`.
20-
explicit HalfVector(const std::vector<float>& value) : value_{value} {}
30+
/// Creates a half vector from a `std::vector<pgvector::HalfType>`.
31+
explicit HalfVector(const std::vector<HalfType>& value) : value_{value} {}
2132

22-
/// Creates a half vector from a `std::vector<float>`.
23-
explicit HalfVector(std::vector<float>&& value) : value_{std::move(value)} {}
33+
/// Creates a half vector from a `std::vector<pgvector::HalfType>`.
34+
explicit HalfVector(std::vector<HalfType>&& value) : value_{std::move(value)} {}
2435

2536
/// Creates a half vector from a span.
26-
explicit HalfVector(std::span<const float> value) : value_{std::vector<float>(value.begin(), value.end())} {}
37+
explicit HalfVector(std::span<const HalfType> value) : value_{std::vector<HalfType>(value.begin(), value.end())} {}
2738

2839
/// Returns the number of dimensions.
2940
size_t dimensions() const {
3041
return value_.size();
3142
}
3243

33-
/// Returns the half vector as a `std::vector<float>`.
34-
operator const std::vector<float>() const {
44+
/// Returns the half vector as a `std::vector<pgvector::HalfType>`.
45+
operator const std::vector<HalfType>() const {
3546
return value_;
3647
}
3748

38-
/// Returns the half vector as a `std::span<const float>`.
39-
operator const std::span<const float>() const {
49+
/// Returns the half vector as a `std::span<const pgvector::HalfType>`.
50+
operator const std::span<const HalfType>() const {
4051
return value_;
4152
}
4253

@@ -57,7 +68,6 @@ class HalfVector {
5768
}
5869

5970
private:
60-
// TODO use std::float16_t for C++23
61-
std::vector<float> value_;
71+
std::vector<HalfType> value_;
6272
};
6373
} // namespace pgvector

include/pgvector/pqxx.hpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -93,23 +93,23 @@ template <> struct string_traits<pgvector::HalfVector> {
9393
throw conversion_error("Malformed halfvec literal");
9494
}
9595

96-
std::vector<float> values;
96+
std::vector<pgvector::HalfType> values;
9797
if (text.size() > 2) {
9898
std::string_view inner = text.substr(1, text.size() - 2);
9999
size_t start = 0;
100100
for (size_t i = 0; i < inner.size(); i++) {
101101
if (inner[i] == ',') {
102-
values.push_back(string_traits<float>::from_string(inner.substr(start, i - start), c));
102+
values.push_back(static_cast<pgvector::HalfType>(string_traits<float>::from_string(inner.substr(start, i - start), c)));
103103
start = i + 1;
104104
}
105105
}
106-
values.push_back(string_traits<float>::from_string(inner.substr(start), c));
106+
values.push_back(static_cast<pgvector::HalfType>(string_traits<float>::from_string(inner.substr(start), c)));
107107
}
108108
return pgvector::HalfVector(std::move(values));
109109
}
110110

111111
static std::string_view to_buf(std::span<char> buf, const pgvector::HalfVector& value, ctx c = {}) {
112-
std::span<const float> values{value};
112+
std::span<const pgvector::HalfType> values{value};
113113

114114
// important! size_buffer cannot throw an exception on overflow
115115
// so perform this check before writing any data
@@ -124,7 +124,7 @@ template <> struct string_traits<pgvector::HalfVector> {
124124
if (i != 0) {
125125
buf[here++] = ',';
126126
}
127-
here += pqxx::into_buf(buf.subspan(here), values[i], c);
127+
here += pqxx::into_buf(buf.subspan(here), static_cast<float>(values[i]), c);
128128
}
129129

130130
buf[here++] = ']';
@@ -133,15 +133,15 @@ template <> struct string_traits<pgvector::HalfVector> {
133133
}
134134

135135
static size_t size_buffer(const pgvector::HalfVector& value) noexcept {
136-
std::span<const float> values{value};
136+
std::span<const pgvector::HalfType> values{value};
137137

138138
// cannot throw an exception here on overflow
139139
// so throw in into_buf
140140

141141
size_t size = 2; // [ and ]
142142
for (const auto v : values) {
143143
size += 1; // ,
144-
size += string_traits<float>::size_buffer(v);
144+
size += string_traits<float>::size_buffer(static_cast<float>(v));
145145
}
146146
return size;
147147
}

test/halfvec_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ static void test_constructor_vector() {
1212
}
1313

1414
static void test_constructor_span() {
15-
auto vec = HalfVector(std::span<const float>({1, 2, 3}));
15+
auto vec = HalfVector(std::span<const pgvector::HalfType>({1, 2, 3}));
1616
assert_equal(vec.dimensions(), 3u);
1717
}
1818

test/pqxx_test.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ void test_halfvec(pqxx::connection &conn) {
4949

5050
pqxx::nontransaction tx(conn);
5151
auto embedding = pgvector::HalfVector({1, 2, 3});
52-
float arr[] = {4, 5, 6};
52+
pgvector::HalfType arr[] = {4, 5, 6};
5353
auto embedding2 = pgvector::HalfVector(std::span{arr, 3});
5454
tx.exec("INSERT INTO items (half_embedding) VALUES ($1), ($2), ($3)", {embedding, embedding2, std::nullopt});
5555

@@ -179,16 +179,20 @@ void test_vector_from_string() {
179179

180180
void test_halfvec_to_string() {
181181
assert_equal(pqxx::to_string(pgvector::HalfVector({1, 2, 3})), "[1,2,3]");
182+
#if __STDCPP_FLOAT16_T__
183+
assert_equal(pqxx::to_string(pgvector::HalfVector({-1.234567890123f16})), "[-1.234375]");
184+
#else
182185
assert_equal(pqxx::to_string(pgvector::HalfVector({-1.234567890123})), "[-1.2345679]");
186+
#endif
183187

184188
assert_exception<pqxx::conversion_overrun>([] {
185-
pqxx::to_string(pgvector::HalfVector(std::vector<float>(16001)));
189+
pqxx::to_string(pgvector::HalfVector(std::vector<pgvector::HalfType>(16001)));
186190
}, "halfvec cannot have more than 16000 dimensions");
187191
}
188192

189193
void test_halfvec_from_string() {
190194
assert_equal(pqxx::from_string<pgvector::HalfVector>("[1,2,3]"), pgvector::HalfVector({1, 2, 3}));
191-
assert_equal(pqxx::from_string<pgvector::HalfVector>("[]"), pgvector::HalfVector(std::vector<float>{}));
195+
assert_equal(pqxx::from_string<pgvector::HalfVector>("[]"), pgvector::HalfVector(std::vector<pgvector::HalfType>{}));
192196

193197
assert_exception<pqxx::conversion_error>([] {
194198
auto _ = pqxx::from_string<pgvector::HalfVector>("");

0 commit comments

Comments
 (0)