From 26e4cf80482ffb51bf7f34b9e428f211717c267f Mon Sep 17 00:00:00 2001 From: ogzhanolguncu Date: Mon, 13 May 2024 17:33:10 +0300 Subject: [PATCH] feat: allow custom metadata key --- .husky/pre-commit | 2 +- bun.lockb | Bin 218660 -> 218636 bytes package.json | 5 ++-- src/constants.ts | 9 ++++++ src/rag-chat-base.ts | 4 ++- src/rag-chat.test.ts | 59 ++++++++++++++++++++++++++------------ src/rag-chat.ts | 16 +++++++---- src/services/history.ts | 9 ++---- src/services/retrieval.ts | 57 +++++++++++++++++++++++++++--------- src/types.ts | 14 ++++++++- src/utils.ts | 17 +++++++++-- 11 files changed, 139 insertions(+), 53 deletions(-) diff --git a/.husky/pre-commit b/.husky/pre-commit index 1e0f6f3..4e52d65 100755 --- a/.husky/pre-commit +++ b/.husky/pre-commit @@ -1,4 +1,4 @@ #!/bin/sh . "$(dirname "$0")/_/husky.sh" -bun run lint && bun run fmt && bun test +bun run lint && bun run fmt && bun test \ No newline at end of file diff --git a/bun.lockb b/bun.lockb index 5fa88c752ea35199bc10cbaaf99d537add2e42d8..38f9dd8d7a9799bbca1d826cdc3b233e2a810b11 100755 GIT binary patch delta 14006 zcmeHOYjjjawmwyPKpPMdLI@FpKmq~s4kQ7gK}O~M2w;$>3cjo@ME4A|VS9{m4 zUAyX3oztm0z2Vu0OP+1GY_NKuXvi>ix@cWkVVPxF*ICy5r~u1q3O)pG3~nksJg%HJ7q|i+Vc`7y-_Al0 zto(wkoY`5Hbr~_TJ4eB^-wO5*_l&()wJ&@F+UrreORSfJ!y#XSl3o`;4h1hqFr%1>k;LE>&Yv`G zV$MBqlnDhJD1iY-c?p=KQvhxTo&x4*UW6^Xa}Jk?S7~;B?2PO=S+kBqV8aK&?9g|J zlEdSHOn(dDg#I^+-Spgi3=%?m8xGmfX}d;yVmdVU#<{nZw;7-OIO;O~haU?x6FKZ?R{jcUOsad5l|S*;V*_&5r5lR-tw_K8M=n9o?>l=xH1}tL>WT zRCV5t>Eu??c6KMXD;JZm8%hH0y!2GWQw7~k7{Ua{n!15@c4xQ#HpG6qb5B=ST-07D z3a}H~r>X)wCf4oR4&AMeZhOa6*XNM>Iua_ox?>ilIj#~rrK)^8Ce9u7EOqw9cHPxQ zyFAYAYK@8TE7*8Lm~6+yyH%;3jn81aJl^dZiV1VO*GGJ++G58fxK#r?JHhSBz*I~6 zvF1X%3MFsZF^TS=CYUQ2sMzj7cOlo6IrT@oD$(uQ2e~KeSax2gR97e_9b>ob?Q_xs zA>eAoc6Ysj932PQHBY9g8|{wCJzag9;7UWM?Cl7n*e*|YyH1JDvJ+!dT^%uLIXH0E zF;z{svs2uzZ=o9sT>wYk)gRNH!x3!Pv`ck81gVGPv8HXR+GA&Tak~Ps0KC_O_B|o> zaV)lXNL4HBn67Tu$I!8sw)1ca+Suh?-73Sb!e@yclj>I0mD#B~=M(80+K&dzYn0-{Qy&K%+?P8v_A+i`LXv3P(ep;3cm$rB6A@frhxQIcoOVV z=^ABSrQ%FElCSrFjAx@(M5upA-u+^{c^Ll4YXzXR0>f zgZz|~lQ{=BiF^&#mXRK)bS=q1e@XpS7DJAo+a=Xa?2kgDw z{5u=s!*wt^z(E`$9U)U5Dl(bb42|}g86gF=abuKcp*$M=2eBjbVi$=_X4b|B8!8t4 zwU~BKw3CWPD%8ff8SM5Q+C3w(KU1|u^kmvSC-ODe9=@xwbLXxQTQcibf~k5z6PWcoQHGMU-G;DdJO!1PxIrd>4)$%Vd*6nrB5SJC@39bJZ= zMV|?OA$I;u)mNe?v*B+<_GbnX`z;EX`A!UKV;22DJA2FSFs?4IPaoT_?2dH4N^7*l zAcMhd)Wwf~!EF0Fu_Lps24LzNg1H8oh&~j|MCM(iC75=t0`Lz4E4GHBAvjhHx`3Jd znRoPV(9`}FFdg&|*)4J}Fq1#i{#GgP1Lo7azwodC?`_cV{h6NUiJr_M;0=TR zekXcA_KpD+G_WsbAHrq6Vz=J z)NK>^Zx3*FQ079Z+a|yuaJztmQMXM{w@tu(1pn;;h>pyE%tb8pZ$P&G6Ym#eAf`(pJ{&{<4~P*Y=+vCp&(bot&FYHMm%s0ed; zwVF)5YT-G%T!kq0k>{dO0ZKXLkqv#X6De#oH1!m1QiY10FIOwZhR>xDeiXHv#@5Z-k(KbI9-{8_9w!fdCD7bzq0w24`~P-3-!Is5;u=;-o* zbmTkH@t>@tM#aI!yHdo1DZ0RGWmZS%c&>g$bUcb~D0WvxhgaRKcg--m7z-8quV6yY zErD_Lh2)$Tf$^f2RYC<4Px@KRskp|hql*U-H2WGMMK>cq7CQEoM@#s#+5p34ELw?O zTjHq=CFXch$>MvK)b$n}xAUsa;ROA#?h{}}C+Zf#-BG`FfO#-c zcMQEAd9F-uc835wshVnbBi!W4k0n*I6B<(Idf}oB9T#d(oY?0GzIzo_N8_S$AA#oNHV_TrgZ9Hvw1;-JSj7h3M`gfa`fO zyADhPvVrM94!|AqOkfny3#d?LWwLIi`v;hplXZ*2TcC>r;sL%y!K0mRXoL&=O<)hu z65#sS3G4!10bXJr_OAfD0iJO^4{Qat0c(Mkz$#z`z%^%pBH(eL7+9k5X4+C@mI2QB zID+8Z;4Oq-CNlsYjotzD0|o#Cfk8mLS)8KB+#G^z1AupvM!*d~J)k}?6>(*opf1`S z%>QPODAT_Sy1W|d5?~Fm7AOVQ0qcRyKsnGKf$#+h-gWrm#72NGWE2B@RfF%*3<2JP zT?N1$z1M*^fW5$Lz#iaL;7MR5unJhscdvNcS_7;F__|aP@EE|iwjKlufMGxeFc267 z3;?*iZ46H_z~g{NfW^QZU@mYUz;lV&0M9FU8ZsIf1H6o2d9V9|?<4R5(E~zH;2n(2 zTR>}oceO}>2abF=>;S(C-UE~(zYbUr@Xq)IupHnaRRF-(o`V1$?Y##K1DXRodTtAZ z0^CaoEW}3;kdNN{2H+F<48X@P@0Waa;{)IXunph~DE(l^_h9(i(g@%n-AG1TF!5IOW0}-w@$*|0FOHcoVv3fMV zTYzVwy9#^*d3e|Redxai;5V;xusIK00IGp1;1i&4H+&%dsjsO{&~1WQuSR(?y6N9v zuST0c_tov;*)muA>U5LQPdAImfrz{L{3(BSSB#8J`?$UupBk4Cho|Bq^SgdJSv_Vv z{q$5-W+MCR6ZeN|b=Fk&*Gbo@@m=5okL4*w-KoPPeDB$h+?e+0tdF;! z(#mN9)t8tdFi<6?#{k&*zHvWt@0IFz%HDs6c3tBVV7JTU4M6Az%u@sOAa%@qPU*C1 zKTuClHRj=gda8>{g3;jn$RN1<+%zTynogvqrsYsjl$ii2f*0;On68bQ-?`JGtY0@( zT#8Im`Hc4H2R>^TGaY{`F zhUlAq)~bZ{^N^i>gqxv5bqjUU95(+D=_a#wC?c&eM~3Q7Ll3%~WjWvXBM^N&Z@%fB z)CnOZ#wD_6zMq9$*!*atW5hxzz)*bz+qVKmWJ;&=^M=#Yst$&;hj+12XD2FQv*9kH%IFXG?Y~D6a=Mg$ZT{06# zpkc>2tYwPj`)$jrP)~J>MZ4}ny9sg0akw=&{SJ@t{pjWFl|#>UNNKnN28n*-6Bdvl zmZsTAT*YH1aU@11&=iiuP-U7*P=x&8Wq#XsrM>kdzj@5@fi>(~fU>?!%AN4wEIE!c zR#dp!Y)ZI7eZORxQF!I*ZGBgAczDf|@f0+Nxk8VdPwqs!=S+I0PS)Rrnwgopd4%tW zFO$XwE}i_+N0*)6CB&uqs?l&)rcP3O&819udEB^0!OLlrIZ7ut@%{d#D1P+P=ofmV zYZd7!8KpNV-J!X8yF@qlj32Gj6l$3<`hV1E&(yK{qVoBA+vJYdtDp^;puY@I2aTG< zI6V=Q^uZt{%jElMSSH???@7wm8IJ1{Il8%7GhIKV4tp@US}L(HqjK~^Y#Rp8tF4}* z8G2lxzb8+dS^86@rg$39W=WaJor5Nqdlt{p_eA)dTwknHBK~=$j8E~-QwgS^SdR|y zRvE-5Td&YbKXyB3g)UX9%+qkCzDKDX&)ijdq(-R+gHjMkv%c#wsAcB<^?HFfj8M;w zW%}2FuFadg&x>hh-X`70?A`>sbW^iQ|IxLA&vB&LX7Oe{*KBYHwD7bq*9XEpAHJ%4 GKk;{zUl?Wp delta 14188 zcmeHOd3+T`y6vhgaDf<*C4`^>SwaZ8*%yp(MR^Fw5|#k61Or64LI@!!f)4|}=YoO| z3Mhgi1VjiT0zobWP(WsI6m>vgKu`!ePs1`0S4JJ3bNbeGt{Q#(9Ou8(FQ-p^OMP8c zU0vO&>N{^^$a5P*7WNa1i~A1{=Zj0jiZ)r6RnM|!#|Kze6YvReBXFeQ`vNWNddQzi z%c>8qf*b-qZ8&%GH1{-jZs9vd4n?_!u+Ith#yJ5Fg*>*faN3M1xz=}x!}wG33#UPM zll4g*Zv`gjOwM}UCW99SE{>2WUqFpBeex!2Y3$-Vc^2T=~K}HtMLBZ zNz-yI>l@UO%{c?6e+AgTx);$_Ob@EcjkbbfeXOw&8zTbbE6aw-@=T)`zME8NjdhWZ?dm5j-0%fUxEYchHkUAXjs@PFp z9GEXfLAj`tA1F%73+l%OHd=@fJagB@UY@SDgvN=X<&~jznutcKs*Owhmx^xd5)G6a z#|%}5|7j%Ru6S4>XtxcdfKWlk-M2W z5wEJ+yTnu#9pw_+lsn2LMX=fs)itOK23~jDc@c8nNFO72n_Vngm343hjmB8BEn<|uYD1^4j&8WNx55;l z5~9+?92Fhwa_olgHe0tPJ-sf3p0>cSJusHi?MMkRY2qOj?Q{iss8g5PW{9h*%;|E( zU^sM#k6jUwucG5z;sxc#v5_i^b2&y}*!1wmh)WYYRCK&cgeiBt%P|TAF7bzw3se;> z$5nKKE9gdy7<5-`M$kCqQkfIWRaJt^@iyeHC}XJt)CgUb=n9&P0f?&w@59I?s@S*; z8PHgDOzP?w)EHM8?m}(B$ychfB$uPg=q!~Go90NtfMt(ZYD>p7ai4N0yBtBd1v8-w zV6Qu}Ah8pIRdw4m$9za#?19~ZJ*c$W+G!)a+ugByc2AxOPeE$SiQcK zkjO0EAD|xnY4ZRu6Pfyf0R0977)jO`fM>_-VcYKkGm+^(o&pk?<=hmIz(saDCwpa! z%mVp_r+`^tD!@dhJPqIprvps>O#2Lg_WuSLZx+Dxa~|Y%oNb$Nv9E!7VIBb(_)!4U zf;HdBPk}8pYHL^#GYbHEEi_yVX7v=nL}s*Q0R2h;JlL#KfQd~1HHO!Mnf#gdb+-0L zjJd@ag!;$UuyBUoVC%fhVjF<~6<5_qnfe0&1xUI*y^1_uKZGIsK326_`9AF{1p4o3cnk*|fv!Z(FMe!?vv z_ZOIh1Cbd7hX9qg&3khUGw%Kz1g}xhqkuKR*vZ#Jo@wOYfxZ3lTLR(HY!#b|{C4)9 z|91$4+WeLZDaZE4Zni?1rjR*imKrVr)3lri!?dq}%-ixgqxWa3*5W{Z-q^`ft=Jx> zHg2y=^rE56SpAtJ^#$m8mQ7%em#xOnpE=p?qF#Nm-8+u<7_Vz#M&Ap67TIUYk(u3Z zhCy^zczL<{r_R~-x)obdcm`ki+mYBW;Vdc zWR?vCvtS*gCsSY7$TcwsKm%j1iP_+y>y1NA%&rZCoi5?VkIWOdFdSjHrLmJ~zrpa0 zMo(t8mC?8M>Vqw2Fu-?4wq~(*UiFWd1J+5sI=0h0VB?KfP0ab_4rBihuy?@x2Lhwo z-iq@v+aC>JcVwD|kSPz4X21@Bg0?}1vnU|d#CXTG?t#5Mc%Jbi^ZG9`GMU-MIPgSE zjQ(fb#Lqy-rI2y?cE#~3StYcpJ-fUw8OtHj=V|(YSMXq%`e%)OrQub^PNuyS%t8O0 zk=KG*=@-CEWcqD1vL8!SDsOkB-CHjivp=)gW}_$bH04Gn)4mN%)pjF$C?Jt(-vOqo z!pOTQAdz|Ry(%%i(Q&Ua{2H?YuNgmoX15+PdNQ+b8=1`b?-<#isXENjLWd*9flT>b z!$*ytO#Lw<`!g$c!sy9t*=Zw_DW5emnQK;`44fRh0znb}cBu8CQ}Y~z<>{BnwX8898j7#?eQ9Jn6zlfg{>%u6&CdRAn*vHLT-e3sGs zGtcvo(HAkpE}0Ex!2dOdnwa*-Vdn|w8$W;M36;@b3&XX@T8bkBF9)+`Pn!T_R&0fl z{h5Bx8v824P}wU5jhR6r^UbXt1tc=dwKp6E4(3Z-EHX@Fp6lOOHeeWgS%Gn`Y1M$E zyLQ=t-G*g^LEZ~6+ZHh-uYoxf9|3q}9|d?5*Df2>E*oGx)Giy~Uikf%8ETge(3`c( z23$ViWLvvz(2{Q_NM!a%?Xm%Th<# zD+2CpKR2vs#)l==>!u;!7#&Y@ObaJM95*K4zE(I)G0<`MoqJmh!42!{jNkW0hY#Yc z<9Yx?jDm{o7jR+Z8-Q$+W>2HHADEeB3c;}7q+w*+hVgrJcYj#xDu^3pyblAub{$Bc>Ts3i6kXF5T!- zp}TB!osBLHx+{7PLwp1k&)fyfPUlxaTD z1aPC<0Zh=26zK|{fye<;_e+s2#CzJEA|tZaK*jGOHUL}|UJvF3!!Ie;0sKN^39tz0 z4`c%TaHkyj1Mo-SAg~{J4cH6p(z#?D)z>l`D_-%jWH4r=q=nM1%`T(3k z7NgBe0IvDZ1)czA05gFH0d79r2XIq@J0K%~k-#ogixU_pd`|FPxzFQ*;$5_!ulAg# zI4!jTu+?aPo4~C|F0bqbN|9d;tN}RPai>-T_Qn+A1OdUo3FuA&1A%ZL6X2ZI3(=<0dGKl9rzP)0C)i?1J(g+ zff>L|;6Y$8a2GHf7y*m~vVqa;>KuSGDktI&0B2p!vYb`93!;IifCa!ppcq&LJPgbN zPNOC710Mi)L%tQ@K583)55`Sk{@d_3;0n+J;J1E!su!Kd;R9d@@FtWSffs>Iz<9s~ za6-NV$N;$g@(jRj8BWS)0d8VH53B>qfDHh5C$9p32daTj0Y2O|BgRWW3ebf^?L{a# z!E*cZOW-r$J5+|-mwd?xhdB}`h0g@wHh_=#9k6o)@)-Cp05^Pl0qub(paZ}M4j(X` zfLMSJ>bDT91b7&Dh<(mSFCVqLVdTrtTX5ua^Az9&ngHi?)6Q~4(L$JeqTKU9DKH1c z_`n+t@H?C5;QB1^9dH?7IDW{p2ENY$nZT_8pIj?p=aws9p!m4r{^7kN!+r(&2aJ5T;R*Ss#W$;u;q(bm1zZ5m0~djwH{*cx=brj}yu7I%hXc#y zTAqG4%SW#h)AX^PGHPirnVwJ|FQ~j1UpNPxrTu-|iWc7qdZ%L?CyD_#uF#>4Vf}rZ|sY? z-l)70eA0dTo&IotN5}P(NfEwZdh9RiSDAliS2j8?5lzR-P?esA`k7LCbH5+rMa&8I zR(X2rx492JdB^F0RmHT)_p^|-?axdQy*Cg0se||A;lFb7wR%o}^qi@Ksq&ATV>dKO ze4+=Sy*|&ax_A&;|7<2|yGQTLlrfH@4*OMkb~}9uj!xfiOg>rvMCkG3!BeF8)?ECl zxjAE};fgN&_{0a{3lA)ADC*VM$phrgVxG<)fOB|*NBDkEGNbzC6Q^9eAF{_=k~1OQ zs?d7|pwdk{?b8kn51;ksW`oqEKkV+kvg`No2xml3L=n~Ql{Zf&$;wR1D+_qx*1^D4kWT1IEZ>S7!`OATVo9Wd< zWuiEwGqTZHEA-`|hc$jS2m~T4%^C2#y=YNLDbwc*BJe`KiOyMonOsf%c zrzz?gH&PxKzNn{kRhu3o1Rt^WuPz~y6YtQuyJUv-s3Fkq$dia5m=c_^}BluLW zhwHMk-Rw>=kK~SNBD!`-cBY`$NAU1fB{xt7Q{Uo#k?<5Ys)QpO#tj2bZ^3$m~E$&egakw(A_R z$@AfAIX%#^m+urvv-Q|@@^-y`os1XV^~rT|fn)P}uW6nxSTAP;#ZAs1m#Y(80gZKu TD { const question = sanitizeQuestion(input); const facts = await this.retrievalService.retrieveFromVectorDb({ question, similarityThreshold, + metadataKey, topK, }); return { question, facts }; @@ -76,7 +78,7 @@ export class RAGChatBase { getMessageHistory: (sessionId: string) => this.historyService.getMessageHistory({ sessionId, - length: chatOptions.includeHistory, + length: chatOptions.historyLength, }), inputMessagesKey: "question", historyMessagesKey: "chat_history", diff --git a/src/rag-chat.test.ts b/src/rag-chat.test.ts index f614530..e23cb3e 100644 --- a/src/rag-chat.test.ts +++ b/src/rag-chat.test.ts @@ -9,8 +9,13 @@ import type { StreamingTextResponse } from "ai"; import { DEFAULT_REDIS_DB_NAME, DEFAULT_VECTOR_DB_NAME } from "./constants"; import { RatelimitUpstashError } from "./error"; import { PromptTemplate } from "@langchain/core/prompts"; +import { sleep } from "bun"; describe("RAG Chat with advance configs and direct instances", async () => { + const vector = new Index({ + token: process.env.UPSTASH_VECTOR_REST_TOKEN!, + url: process.env.UPSTASH_VECTOR_REST_URL!, + }); const ragChat = await RAGChat.initialize({ email: process.env.UPSTASH_EMAIL!, token: process.env.UPSTASH_TOKEN!, @@ -21,10 +26,7 @@ describe("RAG Chat with advance configs and direct instances", async () => { temperature: 0, apiKey: process.env.OPENAI_API_KEY, }), - vector: new Index({ - token: process.env.UPSTASH_VECTOR_REST_TOKEN!, - url: process.env.UPSTASH_VECTOR_REST_URL!, - }), + vector, redis: new Redis({ token: process.env.UPSTASH_REDIS_REST_TOKEN!, url: process.env.UPSTASH_REDIS_REST_URL!, @@ -33,10 +35,15 @@ describe("RAG Chat with advance configs and direct instances", async () => { beforeAll(async () => { await ragChat.addContext( - "Paris, the capital of France, is renowned for its iconic landmark, the Eiffel Tower, which was completed in 1889 and stands at 330 meters tall." + "Paris, the capital of France, is renowned for its iconic landmark, the Eiffel Tower, which was completed in 1889 and stands at 330 meters tall.", + "text" ); + //eslint-disable-next-line @typescript-eslint/no-magic-numbers + await sleep(3000); }); + afterAll(async () => await vector.reset()); + test("should get result without streaming", async () => { const result = (await ragChat.chat( "What year was the construction of the Eiffel Tower completed, and what is its height?", @@ -104,6 +111,11 @@ describe("RAG Chat with ratelimit", async () => { token: process.env.UPSTASH_REDIS_REST_TOKEN!, url: process.env.UPSTASH_REDIS_REST_URL!, }); + const vector = new Index({ + token: process.env.UPSTASH_VECTOR_REST_TOKEN!, + url: process.env.UPSTASH_VECTOR_REST_URL!, + }); + const ragChat = await RAGChat.initialize({ email: process.env.UPSTASH_EMAIL!, token: process.env.UPSTASH_TOKEN!, @@ -114,10 +126,7 @@ describe("RAG Chat with ratelimit", async () => { temperature: 0, apiKey: process.env.OPENAI_API_KEY, }), - vector: new Index({ - token: process.env.UPSTASH_VECTOR_REST_TOKEN!, - url: process.env.UPSTASH_VECTOR_REST_URL!, - }), + vector, redis, ratelimit: new Ratelimit({ redis, @@ -128,20 +137,32 @@ describe("RAG Chat with ratelimit", async () => { afterAll(async () => { await redis.flushdb(); + await vector.reset(); }); - test("should throw ratelimit error", async () => { - await ragChat.chat( - "What year was the construction of the Eiffel Tower completed, and what is its height?", - { stream: false } - ); + test( + "should throw ratelimit error", + async () => { + await ragChat.addContext( + "Paris, the capital of France, is renowned for its iconic landmark, the Eiffel Tower, which was completed in 1889 and stands at 330 meters tall.", + "text" + ); + //eslint-disable-next-line @typescript-eslint/no-magic-numbers + await sleep(3000); - const throwable = async () => { - await ragChat.chat("You shall not pass", { stream: false }); - }; + await ragChat.chat( + "What year was the construction of the Eiffel Tower completed, and what is its height?", + { stream: false, metadataKey: "text" } + ); - expect(throwable).toThrowError(RatelimitUpstashError); - }); + const throwable = async () => { + await ragChat.chat("You shall not pass", { stream: false }); + }; + + expect(throwable).toThrowError(RatelimitUpstashError); + }, + { timeout: 10_000 } + ); }); describe("RAG Chat with instance names", async () => { diff --git a/src/rag-chat.ts b/src/rag-chat.ts index 893fc8a..74a64d9 100644 --- a/src/rag-chat.ts +++ b/src/rag-chat.ts @@ -5,6 +5,7 @@ import type { StreamingTextResponse } from "ai"; import { HistoryService } from "./services/history"; import { RateLimitService } from "./services/ratelimit"; +import type { AddContextPayload } from "./services/retrieval"; import { RetrievalService } from "./services/retrieval"; import { QA_TEMPLATE } from "./prompts"; @@ -50,8 +51,9 @@ export class RAGChat extends RAGChatBase { //Sanitizes the given input by stripping all the newline chars then queries vector db with sanitized question. const { question, facts } = await this.prepareChat({ question: input, - similarityThreshold: options.similarityThreshold, - topK: options.topK, + similarityThreshold: options_.similarityThreshold, + metadataKey: options_.metadataKey, + topK: options_.topK, }); return options.stream @@ -60,10 +62,12 @@ export class RAGChat extends RAGChatBase { } /** Context can be either plain text or embeddings */ - async addContext(context: string | number[]) { - const retrievalService = await this.retrievalService.addEmbeddingOrTextToVectorDb(context); - if (retrievalService === "Success") return "OK"; - return "NOT-OK"; + async addContext(context: AddContextPayload[] | string, metadataKey = "text") { + const retrievalServiceStatus = await this.retrievalService.addEmbeddingOrTextToVectorDb( + context, + metadataKey + ); + return retrievalServiceStatus === "Success" ? "OK" : "NOT-OK"; } /** diff --git a/src/services/history.ts b/src/services/history.ts index d876eab..ace3876 100644 --- a/src/services/history.ts +++ b/src/services/history.ts @@ -4,10 +4,7 @@ import { Config } from "../config"; import { ClientFactory } from "../client-factory"; import type { RAGChatConfig } from "../types"; -const DAY_IN_SECONDS = 86_400; -const TOP_6 = 5; - -type GetHistory = { sessionId: string; length?: number }; +type GetHistory = { sessionId: string; length?: number; sessionTTL?: number }; type HistoryInit = Omit & { email: string; token: string; @@ -19,10 +16,10 @@ export class HistoryService { this.redis = redis; } - getMessageHistory({ length = TOP_6, sessionId }: GetHistory) { + getMessageHistory({ length, sessionId, sessionTTL }: GetHistory) { return new CustomUpstashRedisChatMessageHistory({ sessionId, - sessionTTL: DAY_IN_SECONDS, + sessionTTL, topLevelChatHistoryLength: length, client: this.redis, }); diff --git a/src/services/retrieval.ts b/src/services/retrieval.ts index d6b860e..4e03e88 100644 --- a/src/services/retrieval.ts +++ b/src/services/retrieval.ts @@ -4,9 +4,9 @@ import type { RAGChatConfig } from "../types"; import { ClientFactory } from "../client-factory"; import { Config } from "../config"; import { nanoid } from "nanoid"; +import { DEFAULT_METADATA_KEY, DEFAULT_SIMILARITY_THRESHOLD, DEFAULT_TOP_K } from "../constants"; -const SIMILARITY_THRESHOLD = 0.5; -const TOP_K = 5; +export type AddContextPayload = { input: string | number[]; id?: string; metadata?: string }; type RetrievalInit = Omit & { email: string; @@ -15,8 +15,9 @@ type RetrievalInit = Omit & { export type RetrievePayload = { question: string; - similarityThreshold?: number; - topK?: number; + similarityThreshold: number; + metadataKey: string; + topK: number; }; export class RetrievalService { @@ -27,36 +28,64 @@ export class RetrievalService { async retrieveFromVectorDb({ question, - similarityThreshold = SIMILARITY_THRESHOLD, - topK = TOP_K, + similarityThreshold = DEFAULT_SIMILARITY_THRESHOLD, + metadataKey = DEFAULT_METADATA_KEY, + topK = DEFAULT_TOP_K, }: RetrievePayload): Promise { const index = this.index; - const result = await index.query<{ value: string }>({ + const result = await index.query>({ data: question, topK, includeMetadata: true, includeVectors: false, }); - const allValuesUndefined = result.every((embedding) => embedding.metadata?.value === undefined); + const allValuesUndefined = result.every( + (embedding) => embedding.metadata?.[metadataKey] === undefined + ); + if (allValuesUndefined) { throw new TypeError(` - Query to the vector store returned ${result.length} vectors but none had "value" field in their metadata. - Text of your vectors should be in the "value" field in the metadata for the RAG Chat. + Query to the vector store returned ${result.length} vectors but none had "${metadataKey}" field in their metadata. + Text of your vectors should be in the "${metadataKey}" field in the metadata for the RAG Chat. `); } const facts = result .filter((x) => x.score >= similarityThreshold) - .map((embedding, index) => `- Context Item ${index}: ${embedding.metadata?.value ?? ""}`); + .map( + (embedding, index) => `- Context Item ${index}: ${embedding.metadata?.[metadataKey] ?? ""}` + ); return formatFacts(facts); } - async addEmbeddingOrTextToVectorDb(input: string | number[]) { + async addEmbeddingOrTextToVectorDb( + input: AddContextPayload[] | string, + metadataKey = "text" + ): Promise { if (typeof input === "string") { - return this.index.upsert({ data: input, id: nanoid(), metadata: { value: input } }); + return this.index.upsert({ + data: input, + id: nanoid(), + metadata: { [metadataKey]: input }, + }); } - return this.index.upsert({ vector: input, id: nanoid(), metadata: { value: input } }); + const items = input.map((context) => { + const isText = typeof context.input === "string"; + const metadata = context.metadata + ? { [metadataKey]: context.metadata } + : isText + ? { [metadataKey]: context.input } + : {}; + + return { + [isText ? "data" : "vector"]: context.input, + id: context.id ?? nanoid(), + metadata, + }; + }); + + return this.index.upsert(items as Parameters[number]); } public static async init(config: RetrievalInit) { diff --git a/src/types.ts b/src/types.ts index 58823b1..c152b5c 100644 --- a/src/types.ts +++ b/src/types.ts @@ -17,7 +17,12 @@ export type ChatOptions = { /** Length of the conversation history to include in your LLM query. Increasing this may lead to hallucinations. Retrieves the last N messages. * @default 5 */ - includeHistory?: number; + historyLength?: number; + + /** Configuration to retain chat history. After the specified time, the history will be automatically cleared. + * @default 86_400 // 1 day in seconds + */ + historyTTL?: number; /** Configuration to adjust the accuracy of results. * @default 0.5 @@ -33,6 +38,13 @@ export type ChatOptions = { * @default 5 */ topK?: number; + + /** Key of metadata that we use to store additional content . + * @default "text" + * @example {text: "Capital of France is Paris"} + * + */ + metadataKey?: string; }; export type PrepareChatResult = { diff --git a/src/utils.ts b/src/utils.ts index dd0e6b3..00ebb33 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -1,6 +1,14 @@ import type { BaseMessage } from "@langchain/core/messages"; import type { ChatOptions } from "./types"; -import { DEFAULT_CHAT_SESSION_ID, DEFAULT_CHAT_RATELIMIT_SESSION_ID } from "./constants"; +import { + DEFAULT_CHAT_SESSION_ID, + DEFAULT_CHAT_RATELIMIT_SESSION_ID, + DEFAULT_METADATA_KEY, + DEFAULT_SIMILARITY_THRESHOLD, + DEFAULT_TOP_K, + DEFAULT_HISTORY_LENGTH, + DEFAULT_HISTORY_TTL, +} from "./constants"; export const sanitizeQuestion = (question: string) => { return question.trim().replaceAll("\n", " "); @@ -24,8 +32,13 @@ export function appendDefaultsIfNeeded(options: ChatOptions) { return { ...options, sessionId: options.sessionId ?? DEFAULT_CHAT_SESSION_ID, + metadataKey: options.metadataKey ?? DEFAULT_METADATA_KEY, ratelimitSessionId: options.ratelimitSessionId ?? DEFAULT_CHAT_RATELIMIT_SESSION_ID, - } satisfies ChatOptions; + similarityThreshold: options.similarityThreshold ?? DEFAULT_SIMILARITY_THRESHOLD, + topK: options.topK ?? DEFAULT_TOP_K, + historyLength: options.historyLength ?? DEFAULT_HISTORY_LENGTH, + historyTTL: options.historyLength ?? DEFAULT_HISTORY_TTL, + }; } const DEFAULT_DELAY = 20_000;