From 14eca0beb3294a7f6455f6f39de90cd8c34aabb6 Mon Sep 17 00:00:00 2001 From: xiang <1984871009@qq.com> Date: Fri, 5 May 2023 09:57:42 +0800 Subject: [PATCH] refactor: add chat model factory --- index.ts | 25 +++++++++++++++++------- model/base.ts | 39 +++++++++++++++++++++++++++++++++++++ model/index.ts | 48 +++++++++++++++++----------------------------- model/you/index.ts | 4 ++-- 4 files changed, 77 insertions(+), 39 deletions(-) create mode 100644 model/base.ts diff --git a/index.ts b/index.ts index f437515..62f2c8f 100644 --- a/index.ts +++ b/index.ts @@ -1,39 +1,50 @@ -import {You} from "./model/you"; import Koa from 'koa'; import Router from 'koa-router' import bodyParser from 'koa-bodyparser'; +import {ChatModelFactory, Model} from "./model"; const app = new Koa(); const router = new Router(); app.use(bodyParser()); -const you = new You({proxy: process.env.https_proxy || process.env.http_proxy}); +const chatModel = new ChatModelFactory({proxy: process.env.https_proxy || process.env.http_proxy}); interface AskReq { prompt: string; + model: Model; } router.get('/ask', async (ctx) => { - const {prompt} = ctx.query; + const {prompt, model = Model.You} = ctx.query as unknown as AskReq; if (!prompt) { ctx.body = 'please input prompt'; return; } - const res = await you.ask({prompt: prompt as string}); + const chat = chatModel.get(model); + if (!chat) { + ctx.body = 'Unsupported model'; + return; + } + const res = await chat.ask({prompt: prompt as string}); ctx.body = res.text; }); router.get('/ask/stream', async (ctx) => { - const {prompt} = ctx.query; + const {prompt, model = Model.You} = ctx.query as unknown as AskReq; if (!prompt) { ctx.body = 'please input prompt'; return; } + const chat = chatModel.get(model); + if (!chat) { + ctx.body = 'Unsupported model'; + return; + } ctx.set({ "Content-Type": "text/event-stream", "Cache-Control": "no-cache", "Connection": "keep-alive", - }) - const res = await you.askStream({prompt: prompt as string}); + }); + const res = await chat.askStream({prompt: prompt as string}); ctx.body = res.text; }) diff --git a/model/base.ts b/model/base.ts new file mode 100644 index 0000000..831b9c0 --- /dev/null +++ b/model/base.ts @@ -0,0 +1,39 @@ +import {Stream} from "stream"; + +export interface ChatOptions { + proxy?: string; +} + +export interface Response { + text: string | null; + other: any; +} + +export interface ResponseStream { + text: Stream; + other: any; +} + +export interface Request { + prompt: string; + history?: HistoryItem[]; + options?: any; +} + +export interface HistoryItem { + question?: string; + answer?: string; +} + + +export abstract class Chat { + protected proxy: string | undefined; + + constructor(options?: ChatOptions) { + this.proxy = options?.proxy; + } + + public abstract ask(req: Request): Promise + + public abstract askStream(req: Request): Promise +} diff --git a/model/index.ts b/model/index.ts index 831b9c0..b2f40f4 100644 --- a/model/index.ts +++ b/model/index.ts @@ -1,39 +1,27 @@ -import {Stream} from "stream"; +import {Chat, ChatOptions} from "./base"; +import {You} from "./you"; -export interface ChatOptions { - proxy?: string; +export enum Model { + // define new model here + You = 'you', } -export interface Response { - text: string | null; - other: any; -} - -export interface ResponseStream { - text: Stream; - other: any; -} - -export interface Request { - prompt: string; - history?: HistoryItem[]; - options?: any; -} - -export interface HistoryItem { - question?: string; - answer?: string; -} - - -export abstract class Chat { - protected proxy: string | undefined; +export class ChatModelFactory { + private modelMap: Map; + private readonly options: ChatOptions | undefined; constructor(options?: ChatOptions) { - this.proxy = options?.proxy; + this.modelMap = new Map(); + this.options = options; + this.init(); } - public abstract ask(req: Request): Promise + init() { + // register new model here + this.modelMap.set(Model.You, new You(this.options)) + } - public abstract askStream(req: Request): Promise + get(model: Model): Chat | undefined { + return this.modelMap.get(model); + } } diff --git a/model/you/index.ts b/model/you/index.ts index c448c83..d3bce7b 100644 --- a/model/you/index.ts +++ b/model/you/index.ts @@ -5,7 +5,7 @@ import tlsClient from 'tls-client'; import {Session} from "tls-client/dist/esm/sessions"; import {Params} from "tls-client/dist/esm/types"; import {toEventCB, toEventStream} from "../../utils"; -import {Chat, ChatOptions, Request, Response, ResponseStream} from "../index"; +import {Chat, ChatOptions, Request, Response, ResponseStream} from "../base"; const userAgent = new UserAgent(); @@ -65,7 +65,7 @@ interface SearchResult { export class You extends Chat { private session: Session; - constructor(props: ChatOptions) { + constructor(props?: ChatOptions) { super(props); this.session = new tlsClient.Session({clientIdentifier: 'chrome_108'}); this.session.headers = this.getHeaders();