Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vector more functions #216

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
274 changes: 274 additions & 0 deletions libsql-sqlite3/src/vector.c
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,33 @@ float vectorDistanceL2(const Vector *pVector1, const Vector *pVector2){
return 0;
}

void vectorMult(Vector *pVector, double k){
switch (pVector->type) {
case VECTOR_TYPE_FLOAT32:
vectorF32Mult(pVector, k);
break;
case VECTOR_TYPE_FLOAT64:
vectorF64Mult(pVector, k);
break;
default:
assert(0);
}
}

void vectorAdd(Vector *v1, const Vector *v2){
assert( pVector1->type == pVector2->type );
assert( pVector1->dims == pVector2->dims );
switch (v1->type) {
case VECTOR_TYPE_FLOAT32:
vectorF32Add(v1, v2);
break;
case VECTOR_TYPE_FLOAT64:
vectorF64Add(v1, v2);
break;
default:
assert(0);
}
}
const char *sqlite3_type_repr(int type){
switch( type ){
case SQLITE_NULL:
Expand Down Expand Up @@ -590,6 +617,250 @@ static void vectorDistanceCosFunc(
}
}

/*
** Implementation of vector_sum(V...) scalar function.
*/
static void vectorSumFunc(
sqlite3_context *context,
int argc,
sqlite3_value **argv
){
char *pzErrMsg = NULL;
Vector *pSum = NULL, *pVector = NULL;
int i;
int typeSum, dimsSum, typeVector, dimsVector;

if( argc < 1 ){
return;
}
if( detectVectorParameters(argv[0], 0, &typeSum, &dimsSum, &pzErrMsg) != 0 ){
sqlite3_result_error(context, pzErrMsg, -1);
sqlite3_free(pzErrMsg);
goto out_free;
}
pSum = vectorContextAlloc(context, typeSum, dimsSum);
if( pSum == NULL ){
goto out_free;
}
if( vectorParse(argv[0], pSum, &pzErrMsg) < 0 ){
sqlite3_result_error(context, pzErrMsg, -1);
sqlite3_free(pzErrMsg);
goto out_free;
}
pVector = vectorContextAlloc(context, typeSum, dimsSum);
if( pVector == NULL ){
goto out_free;
}
for(i = 1; i < argc; i++){
if( detectVectorParameters(argv[i], 0, &typeVector, &dimsVector, &pzErrMsg) != 0 ){
sqlite3_result_error(context, pzErrMsg, -1);
sqlite3_free(pzErrMsg);
goto out_free;
}
if( typeSum != typeVector ){
pzErrMsg = sqlite3_mprintf("vector_sum: vectors must have the same type: %d != %d", typeSum, typeVector);
sqlite3_result_error(context, pzErrMsg, -1);
sqlite3_free(pzErrMsg);
goto out_free;
}
if( dimsSum != dimsVector ){
pzErrMsg = sqlite3_mprintf("vector_sum: vectors must have the same length: %d != %d", dimsSum, dimsVector);
sqlite3_result_error(context, pzErrMsg, -1);
sqlite3_free(pzErrMsg);
goto out_free;
}
if( vectorParse(argv[i], pVector, &pzErrMsg) < 0 ){
sqlite3_result_error(context, pzErrMsg, -1);
sqlite3_free(pzErrMsg);
goto out_free;
}
vectorAdd(pSum, pVector);
}
vectorSerialize(context, pSum);
out_free:
if( pSum != NULL ){
vectorFree(pSum);
}
if( pVector != NULL ){
vectorFree(pVector);
}
}

struct VectorSumCtx {
i64 count;
Vector *pSum;
Vector *pVector;
};

static void vectorSumAdd(
sqlite3_context *context,
int argc,
sqlite3_value **argv,
double k
){
char *pzErrMsg;
struct VectorSumCtx *p;
int type, dims;
assert( argc == 1 );
UNUSED_PARAMETER(argc);
p = sqlite3_aggregate_context(context, sizeof(*p));
if( detectVectorParameters(argv[0], 0, &type, &dims, &pzErrMsg) != 0 ){
sqlite3_result_error(context, pzErrMsg, -1);
sqlite3_free(pzErrMsg);
return;
}
if( p->count == 0 ){
p->pSum = vectorContextAlloc(context, type, dims);
if( p->pSum == NULL ){
return;
}
}
if( p->pSum->type != type ){
pzErrMsg = sqlite3_mprintf("vector_sum: vectors must have the same type: %d != %d", p->pSum->type, type);
sqlite3_result_error(context, pzErrMsg, -1);
sqlite3_free(pzErrMsg);
return;
}
if( p->pSum->dims != dims ){
pzErrMsg = sqlite3_mprintf("vector_sum: vectors must have the same length: %d != %d", p->pSum->dims, dims);
sqlite3_result_error(context, pzErrMsg, -1);
sqlite3_free(pzErrMsg);
return;
}
if( p->count == 0 ){
if( vectorParse(argv[0], p->pSum, &pzErrMsg) < 0 ){
sqlite3_result_error(context, pzErrMsg, -1);
sqlite3_free(pzErrMsg);
}else{
vectorMult(p->pSum, k);
p->count++;
}
return;
}
if( p->pVector == NULL ){
p->pVector = vectorContextAlloc(context, type, dims);
if( p->pVector == NULL ){
return;
}
}
if( vectorParse(argv[0], p->pVector, &pzErrMsg) < 0 ){
sqlite3_result_error(context, pzErrMsg, -1);
sqlite3_free(pzErrMsg);
return;
}
vectorMult(p->pVector, k);
vectorAdd(p->pSum, p->pVector);
p->count++;
}

static void vectorSumEnd(sqlite3_context *context, int freeMem){
struct VectorSumCtx *p;
p = sqlite3_aggregate_context(context, 0);
if( p && p->count>0 ){
vectorSerialize(context, p->pSum);
}
if( p && p->pSum != NULL && freeMem ){
vectorFree(p->pSum);
}
if( p && p->pVector != NULL && freeMem ){
vectorFree(p->pVector);
}
}

/*
** Implementation of vector_sum aggregate function (step part)
*/
static void vectorSumStep(sqlite3_context *context, int argc, sqlite3_value **argv){
vectorSumAdd(context, argc, argv, 1.0);
}

/*
** Implementation of vector_sum aggregate function (inverse part)
*/
static void vectorSumInverse(sqlite3_context *context, int argc, sqlite3_value **argv){
vectorSumAdd(context, argc, argv, -1.0);
}

/*
** Implementation of vector_sum aggregate function (finalize part)
*/
static void vectorSumFinalize(sqlite3_context *context){
vectorSumEnd(context, 1);
}

/*
** Implementation of vector_sum aggregate function (value part)
*/
static void vectorSumValue(sqlite3_context *context){
vectorSumEnd(context, 0);
}

/*
** Implementation of vector_mult(V, k) / vector_mult(k, V) function.
*/
static void vectorMultFunc(
sqlite3_context *context,
int argc,
sqlite3_value **argv
){
char *pzErrMsg;
sqlite3_value *pMultValue = NULL, *pVectorValue = NULL;
int type, dims;
Vector *pVector;
double k;

assert( argc == 2 );

if( sqlite3_value_type(argv[0]) == SQLITE_INTEGER || sqlite3_value_type(argv[0]) == SQLITE_FLOAT ){
pMultValue = argv[0];
}
if( sqlite3_value_type(argv[1]) == SQLITE_INTEGER || sqlite3_value_type(argv[1]) == SQLITE_FLOAT ){
pMultValue = argv[1];
}
if( sqlite3_value_type(argv[0]) == SQLITE_BLOB || sqlite3_value_type(argv[0]) == SQLITE_TEXT ){
pVectorValue = argv[0];
}
if( sqlite3_value_type(argv[1]) == SQLITE_BLOB || sqlite3_value_type(argv[1]) == SQLITE_TEXT ){
pVectorValue = argv[1];
}
if( pMultValue == NULL || pVectorValue == NULL ){
pzErrMsg = sqlite3_mprintf(
"vector_mult: unexpected parameters: got %s and %s, but expected vector-compatible and float-compatible types",
sqlite3_type_repr(sqlite3_value_type(argv[0])),
sqlite3_type_repr(sqlite3_value_type(argv[1]))
);
sqlite3_result_error(context, pzErrMsg, -1);
sqlite3_free(pzErrMsg);
return;
}

if( detectVectorParameters(pVectorValue, 0, &type, &dims, &pzErrMsg) != 0 ){
sqlite3_result_error(context, pzErrMsg, -1);
sqlite3_free(pzErrMsg);
return;
}
if( sqlite3_value_type(pMultValue) == SQLITE_INTEGER ){
k = sqlite3_value_int64(pMultValue);
}
if( sqlite3_value_type(pMultValue) == SQLITE_FLOAT ){
k = sqlite3_value_double(pMultValue);
}
pVector = vectorContextAlloc(context, type, dims);
if( pVector == NULL ){
return;
}
if( vectorParse(pVectorValue, pVector, &pzErrMsg)<0 ){
sqlite3_result_error(context, pzErrMsg, -1);
sqlite3_free(pzErrMsg);
goto out_free;
}

vectorMult(pVector, k);
vectorSerialize(context, pVector);
out_free:
vectorFree(pVector);
}

/*
* Marker function which is used in index creation syntax: CREATE INDEX idx ON t(libsql_vector_idx(emb));
*/
Expand All @@ -607,7 +878,10 @@ void sqlite3RegisterVectorFunctions(void){
FUNCTION(vector32, 1, 0, 0, vector32Func),
FUNCTION(vector64, 1, 0, 0, vector64Func),
FUNCTION(vector_extract, 1, 0, 0, vectorExtractFunc),
FUNCTION(vector_sum, -1, 0, 0, vectorSumFunc),
FUNCTION(vector_mult, 2, 0, 0, vectorMultFunc),
FUNCTION(vector_distance_cos, 2, 0, 0, vectorDistanceCosFunc),
WAGGREGATE(vector_sum, 1, 0, 0, vectorSumStep, vectorSumFinalize, vectorSumFinalize, vectorSumInverse, SQLITE_FUNC_ANYORDER),

FUNCTION(libsql_vector_idx, -1, 0, 0, libsqlVectorIdx),
};
Expand Down
13 changes: 13 additions & 0 deletions libsql-sqlite3/src/vectorInt.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,19 @@ float vectorDistanceL2 (const Vector *, const Vector *);
float vectorF32DistanceL2 (const Vector *, const Vector *);
double vectorF64DistanceL2(const Vector *, const Vector *);

/*
* Multiply vector in-place by floating point constant k
*/
void vectorMult (Vector *, double);
void vectorF32Mult(Vector *, double);
void vectorF64Mult(Vector *, double);

/*
* Add second vector argument to first vector in-place
*/
void vectorAdd (Vector *, const Vector *);
void vectorF32Add(Vector *, const Vector *);
void vectorF64Add(Vector *, const Vector *);
/*
* Serializes vector to the sqlite_blob in little-endian format according to the IEEE-754 standard
* LibSQL can append one trailing byte in the end of final blob. This byte will be later used to determine type of the blob
Expand Down
25 changes: 25 additions & 0 deletions libsql-sqlite3/src/vectorfloat32.c
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,31 @@ float vectorF32DistanceL2(const Vector *v1, const Vector *v2){
return sqrt(sum);
}

void vectorF32Mult(Vector *v, double k){
float *e = v->data;
int i;

assert( v->type == VECTOR_TYPE_FLOAT32 );

for(i = 0; i < v->dims; i++){
e[i] *= k;
}
}

void vectorF32Add(Vector *v1, const Vector *v2){
float *e1 = v1->data;
float *e2 = v2->data;
int i;

assert( v1->type == VECTOR_TYPE_FLOAT32 );
assert( v1->type == v2->type );
assert( v1->dims == v2->dims );

for(i = 0; i < v1->dims; i++){
e1[i] += e2[i];
}
}

void vectorF32InitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlobSize){
pVector->dims = nBlobSize / sizeof(float);
pVector->data = (void*)pBlob;
Expand Down
25 changes: 25 additions & 0 deletions libsql-sqlite3/src/vectorfloat64.c
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,31 @@ double vectorF64DistanceL2(const Vector *v1, const Vector *v2){
return sqrt(sum);
}

void vectorF64Mult(Vector *v, double k){
double *e = v->data;
int i;

assert( v->type == VECTOR_TYPE_FLOAT64 );

for(i = 0; i < v->dims; i++){
e[i] *= k;
}
}

void vectorF64Add(Vector *v1, const Vector *v2){
double *e1 = v1->data;
double *e2 = v2->data;
int i;

assert( v1->type == VECTOR_TYPE_FLOAT64 );
assert( v1->type == v2->type );
assert( v1->dims == v2->dims );

for(i = 0; i < v1->dims; i++){
e1[i] += e2[i];
}
}

void vectorF64InitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlobSize){
pVector->dims = nBlobSize / sizeof(double);
pVector->data = (void*)pBlob;
Expand Down
Loading
Loading