Skip to content

Commit

Permalink
resolve the issue #171
Browse files Browse the repository at this point in the history
  • Loading branch information
masajiro committed Oct 31, 2024
1 parent 9908df2 commit fae388a
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 33 deletions.
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.3.0
2.3.2
2 changes: 2 additions & 0 deletions lib/NGT/Common.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ namespace NGT {

class quint8 {
public:
quint8(){}
quint8(uint8_t v):value(v){}
quint8 &operator=(uint8_t v) { value = v; return *this; }
operator uint8_t() const { return value; }
Expand All @@ -67,6 +68,7 @@ namespace NGT {
};
class qsint8 {
public:
qsint8(){}
qsint8(int8_t v):value(v){}
qsint8 &operator=(int8_t v) { value = v; return *this; }
operator int8_t() const { return value; }
Expand Down
36 changes: 28 additions & 8 deletions lib/NGT/Index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,17 @@ NGT::Index::createGraphAndTree(const string &database, NGT::Property &prop, cons
StdOstreamRedirector redirector(redirect);
redirector.begin();
try {
loadAndCreateIndex(*idx, database, dataFile, prop.threadPoolSize, dataSize);
if (idx->getObjectSpace().isQintObjectType()) {
idx->saveIndex(database);
idx->close();
auto append = true;
auto refinement = false;
if (!dataFile.empty()) {
appendFromTextObjectFile(database, dataFile, dataSize, append, refinement, prop.threadPoolSize);
}
} else {
loadAndCreateIndex(*idx, database, dataFile, prop.threadPoolSize, dataSize);
}
} catch(Exception &err) {
delete idx;
redirector.end();
Expand All @@ -169,7 +179,17 @@ NGT::Index::createGraph(const string &database, NGT::Property &prop, const strin
StdOstreamRedirector redirector(redirect);
redirector.begin();
try {
loadAndCreateIndex(*idx, database, dataFile, prop.threadPoolSize, dataSize);
if (idx->getObjectSpace().isQintObjectType()) {
idx->saveIndex(database);
idx->close();
auto append = true;
auto refinement = false;
if (!dataFile.empty()) {
appendFromTextObjectFile(database, dataFile, dataSize, append, refinement, prop.threadPoolSize);
}
} else {
loadAndCreateIndex(*idx, database, dataFile, prop.threadPoolSize, dataSize);
}
} catch(Exception &err) {
delete idx;
redirector.end();
Expand Down Expand Up @@ -248,10 +268,10 @@ NGT::Index::append(const string &database, const float *data, size_t dataSize, s
}

void
NGT::Index::appendFromRefinementObjectFile(const std::string &indexPath) {
NGT::Index::appendFromRefinementObjectFile(const std::string &indexPath, size_t threadSize) {
NGT::Index index(indexPath);
index.appendFromRefinementObjectFile();
index.createIndex();
index.createIndex(threadSize);
index.save();
index.close();
}
Expand Down Expand Up @@ -439,12 +459,12 @@ NGT::Index::insertFromRefinementObjectFile() {

void
NGT::Index::appendFromTextObjectFile(const std::string &indexPath, const std::string &data, size_t dataSize,
bool append, bool refinement) {
bool append, bool refinement, size_t threadSize) {
//#define APPEND_TEST

NGT::Index index(indexPath);
index.appendFromTextObjectFile(data, dataSize, append, refinement);
index.createIndex();
index.createIndex(threadSize);
index.save();
index.close();
}
Expand Down Expand Up @@ -612,10 +632,10 @@ NGT::Index::appendFromTextObjectFile(const std::string &data, size_t dataSize, b

void
NGT::Index::appendFromBinaryObjectFile(const std::string &indexPath, const std::string &data,
size_t dataSize, bool append, bool refinement) {
size_t dataSize, bool append, bool refinement, size_t threadSize) {
NGT::Index index(indexPath);
index.appendFromBinaryObjectFile(data, dataSize, append, refinement);
index.createIndex();
index.createIndex(threadSize);
index.save();
index.close();
}
Expand Down
10 changes: 5 additions & 5 deletions lib/NGT/Index.h
Original file line number Diff line number Diff line change
Expand Up @@ -552,14 +552,14 @@ namespace NGT {
#endif
static void append(const std::string &index, const std::string &dataFile, size_t threadSize, size_t dataSize);
static void append(const std::string &index, const float *data, size_t dataSize, size_t threadSize);
static void appendFromRefinementObjectFile(const std::string &index);
static void appendFromRefinementObjectFile(const std::string &index, size_t threadSize = 0);
void appendFromRefinementObjectFile();
void insertFromRefinementObjectFile();
static void appendFromTextObjectFile(const std::string &index, const std::string &data,
size_t dataSize, bool append = true, bool refinement = false);
static void appendFromTextObjectFile(const std::string &index, const std::string &data, size_t dataSize,
bool append = true, bool refinement = false, size_t threadSize = 0);
void appendFromTextObjectFile(const std::string &data, size_t dataSize, bool append = true, bool refinement = false);
static void appendFromBinaryObjectFile(const std::string &index, const std::string &data,
size_t dataSize, bool append = true, bool refinement = false);
static void appendFromBinaryObjectFile(const std::string &index, const std::string &data, size_t dataSize,
bool append = true, bool refinement = false, size_t threadSize = 0);
// Non-static overload that operates on this index instance. The static overload of the
// same name opens the index and forwards its data, dataSize, append and refinement
// arguments to this function, so the parameter names here mirror that call.
void appendFromBinaryObjectFile(const std::string &data, size_t dataSize, bool append = true, bool refinement = false);
static void remove(const std::string &database, std::vector<ObjectID> &objects, bool force = false);
static void exportIndex(const std::string &database, const std::string &file);
Expand Down
3 changes: 2 additions & 1 deletion lib/NGT/NGTQ/QbgCli.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ class QbgCliBuildParameters : public QBG::BuildParameters {
transform(clusterDataType.begin(), clusterDataType.end(), clusterDataType.begin(), ::tolower);
if (clusterDataType == "-" || clusterDataType == "pq4") {
creation.localClusterDataType = NGTQ::ClusterDataTypePQ4;
} else if (clusterDataType == "sqsu8" || clusterDataType == "sqs8" || clusterDataType == "sq8") {
} else if (clusterDataType == "sqsu8" || clusterDataType == "sqs8" || clusterDataType == "sq8" ||
clusterDataType == "qsu8" || clusterDataType == "qs8") {
creation.localClusterDataType = NGTQ::ClusterDataTypeSQSU8;
} else if (clusterDataType == "nq") {
creation.localClusterDataType = NGTQ::ClusterDataTypeNQ;
Expand Down
53 changes: 35 additions & 18 deletions lib/NGT/PrimitiveComparator.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,21 +111,21 @@ namespace NGT {
#if defined(NGT_NO_AVX)
template <typename OBJECT_TYPE, typename COMPARE_TYPE>
inline static double compareL2(const OBJECT_TYPE *a, const OBJECT_TYPE *b, size_t size) {
const OBJECT_TYPE *last = a + size;
const OBJECT_TYPE *lastgroup = last - 3;
auto *last = a + size;
auto *lastgroup = last - 3;
COMPARE_TYPE diff0, diff1, diff2, diff3;
double d = 0.0;
while (a < lastgroup) {
diff0 = static_cast<COMPARE_TYPE>(a[0] - b[0]);
diff1 = static_cast<COMPARE_TYPE>(a[1] - b[1]);
diff2 = static_cast<COMPARE_TYPE>(a[2] - b[2]);
diff3 = static_cast<COMPARE_TYPE>(a[3] - b[3]);
diff0 = static_cast<COMPARE_TYPE>(a[0]) - b[0];
diff1 = static_cast<COMPARE_TYPE>(a[1]) - b[1];
diff2 = static_cast<COMPARE_TYPE>(a[2]) - b[2];
diff3 = static_cast<COMPARE_TYPE>(a[3]) - b[3];
d += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3;
a += 4;
b += 4;
}
while (a < last) {
diff0 = static_cast<COMPARE_TYPE>(*a++ - *b++);
diff0 = static_cast<COMPARE_TYPE>(*a++) - static_cast<COMPARE_TYPE>(*b++);
d += diff0 * diff0;
}
return sqrt(static_cast<double>(d));
Expand All @@ -148,6 +148,9 @@ namespace NGT {
return compareL2<bfloat16, float>(a, b, size);
}
#endif
inline static double compareL2(const quint8 *a, const quint8 *b, size_t size) {
  // Forward to the generic compareL2 template, with float as the type used
  // for the per-element differences.
  const double distance = compareL2<quint8, float>(a, b, size);
  return distance;
}
#else
inline static double compareL2(const float *a, const float *b, size_t size) {
const float *last = a + size;
Expand Down Expand Up @@ -407,7 +410,7 @@ namespace NGT {

inline static double compareL2(const qsint8 *a, const quint8 *b, size_t size) {
NGTThrowException("Not supported.");
return 0.00;
return 0.0;
}

template <typename OBJECT_TYPE>
Expand All @@ -422,15 +425,15 @@ namespace NGT {

template <typename OBJECT_TYPE, typename COMPARE_TYPE>
static double compareL1(const OBJECT_TYPE *a, const OBJECT_TYPE *b, size_t size) {
const OBJECT_TYPE *last = a + size;
const OBJECT_TYPE *lastgroup = last - 3;
auto *last = a + size;
auto *lastgroup = last - 3;
COMPARE_TYPE diff0, diff1, diff2, diff3;
double d = 0.0;
while (a < lastgroup) {
diff0 = (COMPARE_TYPE)(a[0] - b[0]);
diff1 = (COMPARE_TYPE)(a[1] - b[1]);
diff2 = (COMPARE_TYPE)(a[2] - b[2]);
diff3 = (COMPARE_TYPE)(a[3] - b[3]);
diff0 = (COMPARE_TYPE)(a[0]) - b[0];
diff1 = (COMPARE_TYPE)(a[1]) - b[1];
diff2 = (COMPARE_TYPE)(a[2]) - b[2];
diff3 = (COMPARE_TYPE)(a[3]) - b[3];
d += absolute(diff0) + absolute(diff1) + absolute(diff2) + absolute(diff3);
a += 4;
b += 4;
Expand Down Expand Up @@ -464,6 +467,12 @@ namespace NGT {
return compareL1<bfloat16, float>(a, b, size);
}
#endif
inline static double compareL1(const quint8 *a, const quint8 *b, size_t size) {
  // Forward to the generic compareL1 template, with float as the type used
  // for the per-element differences.
  const double distance = compareL1<quint8, float>(a, b, size);
  return distance;
}
inline static double compareL1(const qsint8 *a, const qsint8 *b, size_t size) {
  // Forward to the generic compareL1 template, with float as the type used
  // for the per-element differences.
  const double distance = compareL1<qsint8, float>(a, b, size);
  return distance;
}
#else
inline static double compareL1(const float *a, const float *b, size_t size) {
__m256 sum = _mm256_setzero_ps();
Expand Down Expand Up @@ -732,6 +741,14 @@ namespace NGT {
return sum;
}

inline static double compareDotProduct(const qsint8 *a, const quint8 *b, size_t size) {
double sum = 0.0;
for (size_t loc = 0; loc < size; loc++) {
sum += static_cast<int32_t>(a[loc]) * static_cast<int32_t>(b[loc]);
}
return sum;
}

template <typename OBJECT_TYPE>
inline static double compareCosine(const OBJECT_TYPE *a, const OBJECT_TYPE *b, size_t size) {
double normA = 0.0;
Expand Down Expand Up @@ -1153,6 +1170,7 @@ namespace NGT {
inline static double compareCosine(const qsint8 *a, const qsint8 *b, size_t size) {
return compareCosine(reinterpret_cast<const uint8_t*>(a), reinterpret_cast<const uint8_t*>(b), size);
}
#endif // #if defined(NGT_NO_AVX)

inline static double compareNormalizedCosineSimilarity(const float *a, const float *b, size_t size) {
auto v = 1.0 - compareDotProduct(a, b, size);
Expand Down Expand Up @@ -1182,7 +1200,6 @@ namespace NGT {
auto v = max - compareDotProduct(a, b, size);
return v;
}
#endif // #if defined(NGT_NO_AVX)

template <typename OBJECT_TYPE>
inline static double compareAngleDistance(const OBJECT_TYPE *a, const OBJECT_TYPE *b, size_t size) {
Expand Down Expand Up @@ -1512,14 +1529,14 @@ namespace NGT {
class L1Qsint8 {
public:
inline static double compare(const void *a, const void *b, size_t size) {
NGTThrowException("Not supported.");
return PrimitiveComparator::compareL1((const qsint8*)a, (const qsint8*)b, size);
}
};

class CosineSimilarityQsint8 {
public:
inline static double compare(const void *a, const void *b, size_t size) {
NGTThrowException("Not supported.");
return PrimitiveComparator::compareCosineSimilarity((const qsint8*)a, (const qsint8*)b, size);
}
};

Expand Down Expand Up @@ -1564,7 +1581,7 @@ namespace NGT {
class NormalizedCosineSimilarityQsint8 {
public:
inline static double compare(const void *a, const void *b, size_t size) {
float max = 127.0 * 127.0 * size;
float max = 127.0 * 255.0;
auto d = max - PrimitiveComparator::compareDotProduct((const qsint8*)a, (const qsint8*)b, size);
return d;
}
Expand Down

0 comments on commit fae388a

Please sign in to comment.