gpt4free-ts/router.ts
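// Koa routing layer for gpt4free-ts: wires the aggregated chat sites into
// plain ask/stream endpoints plus OpenAI- and Claude-compatible REST APIs.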

import Koa, { Context, Middleware, Next } from 'koa';
import {
  ChatRequest,
  ChatResponse,
  countMessagesToken,
  Message,
  ModelType,
  Site,
} from './model/base';
import {
  ClaudeEventStream,
  ComError,
  Event,
  EventStream,
  getTokenCount,
  OpenaiEventStream,
  parseJSON,
  randomStr,
  ThroughEventStream,
} from './utils';
import { chatModel } from './model';
import moment from 'moment/moment';
import cors from '@koa/cors';
import Router from 'koa-router';
import bodyParser from 'koa-bodyparser';
import { TraceLogger } from './utils/log';
import apm from 'elastic-apm-node';
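
// GET /supports: list every registered site together with the models it
// reports as supported.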
const supportsHandler = async (ctx: Context) => {
  const result: Support[] = [];
  for (const key in Site) {
    //@ts-ignore
    const site = Site[key];
    //@ts-ignore
    const chat = chatModel.get(site);
    const support: Support = { site: site, models: [] };
    for (const mKey in ModelType) {
      //@ts-ignore
      const model = ModelType[mKey];
      //@ts-ignore
      if (chat?.support(model)) {
        support.models.push(model);
      }
    }
    result.push(support);
  }
  ctx.body = result;
};
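
// Catch-all error middleware: logs the failure and converts any thrown
// error into a JSON body with the appropriate HTTP status.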
const errorHandler = async (ctx: Context, next: Next) => {
  try {
    await next();
  } catch (err: any) {
    ctx.logger?.info(err.message, {
      trace_label: 'error',
      ...(ctx.query as any),
      ...(ctx.request.body as any),
      ...(ctx.params as any),
    });
    ctx.body = { error: { message: err.message } };
    ctx.status = err.status || ComError.Status.InternalServerError;
  }
};
interface AskReq extends ChatRequest {
  site: Site;
}

interface AskRes extends ChatResponse {}
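
// Optional auth: reads the Bearer token from Authorization or x-api-key and,
// when the API_KEY env var is set, rejects requests that do not match it.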
async function checkApiKey(ctx: Context, next: Next) {
  let secret = '';
  const authorStr =
    ctx.request.headers.authorization || ctx.request.headers['x-api-key'];
  secret = ((authorStr as string) || '').replace(/Bearer /, '');
  ctx.query = { ...ctx.query, secret };
  if (!process.env.API_KEY) {
    await next();
    return;
  }
  if (secret !== process.env.API_KEY) {
    throw new ComError('invalid api key', 401);
  }
  await next();
}
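
// Non-streaming ask endpoint: resolves the target site, normalizes the
// prompt into a message list, and returns the model's full reply at once.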
const AskHandle: Middleware = async (ctx) => {
  const {
    prompt,
    model = ModelType.GPT3p5Turbo,
    site = Site.You,
    ...rest
  } = {
    ...(ctx.query as any),
    ...(ctx.request.body as any),
    ...(ctx.params as any),
  } as AskReq;
  if (model !== ModelType.GetGizmoInfo && !prompt) {
    throw new ComError(`prompt is required`, ComError.Status.BadRequest);
  }
  const chat = chatModel.get(site);
  if (!chat) {
    throw new ComError(`unsupported site: ${site}`, ComError.Status.NotFound);
  }
  let req: ChatRequest = {
    ...rest,
    prompt,
    messages: parseJSON<Message[]>(prompt, [{ role: 'user', content: prompt }]),
    model,
  };
  if (typeof req.messages !== 'object') {
    // a bare numeric prompt parses to a number, not a message array
    req.messages = [{ role: 'user', content: prompt }];
  }
  req = await chat.preHandle(req);
  const data = await chat.ask(req);
  if (data && data.error) {
    ctx.status = 500;
  }
  req.messages.push({ role: 'assistant', content: data.content || '' });
  console.debug(req.messages);
  ctx.body = data;
  return req;
};
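
// Streaming ask endpoint factory: parameterized by the event-stream flavor
// (raw, OpenAI, or Claude) so the same pipeline can serve every wire format.
// A 120s watchdog aborts the stream if the upstream site never responds.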
const AskStreamHandle: (ESType: new () => EventStream) => Middleware =
  (ESType) => async (ctx) => {
    const {
      prompt,
      model = ModelType.GPT3p5Turbo,
      site = Site.You,
      ...rest
    } = {
      ...(ctx.query as any),
      ...(ctx.request.body as any),
      ...(ctx.params as any),
    } as AskReq;
    apm.currentTransaction?.addLabels({ site, model }, true);
    if (model !== ModelType.GetGizmoInfo && !prompt) {
      throw new ComError(`prompt is required`, ComError.Status.BadRequest);
    }
    const chat = chatModel.get(site);
    if (!chat) {
      throw new ComError(
        `unsupported site: ${site}`,
        ComError.Status.NotFound,
      );
    }
    let req: ChatRequest = {
      ...rest,
      prompt,
      messages: parseJSON<Message[]>(prompt, [
        { role: 'user', content: prompt },
      ]),
      model,
    };
    if (typeof req.messages !== 'object') {
      // a bare numeric prompt parses to a number, not a message array
      req.messages = [{ role: 'user', content: prompt }];
    }
    let stream = new ESType();
    stream.setModel(req.model);
    req = await chat.preHandle(req, { stream });
    ctx.logger.info('start', {
      model,
      req: ctx.req,
      res: ctx.res,
      trace_label: 'start',
    });
    let ok = true;
    const timeout = setTimeout(() => {
      stream.write(Event.error, { error: 'timeout' });
      stream.write(Event.done, { content: '' });
      stream.end();
    }, 120 * 1000);
    const input = req.messages;
    let output = '';
    return (() =>
      new Promise<void>(async (resolve, reject) => {
        try {
          const es = new ThroughEventStream(
            (event, data: any) => {
              switch (event) {
                case Event.error:
                  ctx.logger.info(data.error, {
                    req: ctx.req,
                    res: ctx.res,
                    trace_label: 'error',
                  });
                  clearTimeout(timeout);
                  if (data instanceof ComError) {
                    reject(data);
                    return;
                  }
                  ok = false;
                  reject(
                    new ComError(
                      (data as any)?.error || 'unknown error',
                      (data as any)?.status ||
                        ComError.Status.InternalServerError,
                    ),
                  );
                  break;
                default:
                  if (!ctx.body && !data.content) {
                    break;
                  }
                  clearTimeout(timeout);
                  if (!ok) {
                    break;
                  }
                  if (!ctx.body) {
                    ctx.set({
                      'Content-Type': 'text/event-stream;charset=utf-8',
                      'Cache-Control': 'no-cache',
                      Connection: 'keep-alive',
                    });
                    ctx.body = stream.stream();
                    ctx.logger.info('recv', {
                      model,
                      req: ctx.req,
                      res: ctx.res,
                      trace_label: 'recv',
                    });
                  }
                  resolve();
                  stream.write(event, data);
                  output += (data as any).content || '';
                  break;
              }
            },
            () => {
              if (!ok) {
                return;
              }
              input.push({ role: 'assistant', content: output });
              delete (req as any).prompt;
              ctx.logger.info(JSON.stringify(req), {
                model,
                req: ctx.req,
                res: ctx.res,
                trace_label: 'end',
              });
              stream.end();
            },
          );
          await chat.askStream(req, es).catch((err) => {
            clearTimeout(timeout);
            es.destroy();
            reject(err);
          });
        } catch (e) {
          reject(e);
        }
      }))();
  };
interface OpenAIReq {
  site: Site;
  stream: boolean;
  model: ModelType;
  messages: Message[];
}

interface ClaudeReq {
  site: Site;
  stream: boolean;
  model: ModelType;
  prompt: string;
}

interface Support {
  site: string;
  models: string[];
}
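
// OpenAI-compatible /v1/chat/completions: re-encodes the messages array as a
// prompt, then wraps the reply in the standard chat.completion envelope with
// estimated token usage.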
const openAIHandle: Middleware = async (ctx, next) => {
  const { stream, messages, model } = {
    ...(ctx.query as any),
    ...(ctx.request.body as any),
    ...(ctx.params as any),
  } as OpenAIReq;
  (ctx.request.body as any).prompt = JSON.stringify(
    (ctx.request.body as any).messages,
  );
  if (stream) {
    await AskStreamHandle(OpenaiEventStream)(ctx, next);
    return;
  }
  const req: ChatRequest = await AskHandle(ctx, next);
  // Estimate prompt tokens: text tokens plus, for each image, the OpenAI
  // vision formula (85 base tokens + 170 tokens per 512px tile).
  let reqLen = countMessagesToken(messages);
  const tileSize = 512;
  const tokensPerTile = 170;
  for (const v of req.images || []) {
    const tilesForWidth = Math.ceil(v.width / tileSize);
    const tilesForHeight = Math.ceil(v.height / tileSize);
    const totalTiles = tilesForWidth * tilesForHeight;
    const totalTokens = 85 + tokensPerTile * totalTiles;
    reqLen += totalTokens;
  }
  const completion_tokens = getTokenCount(ctx.body.content || '');
  ctx.body = {
    id: 'chatcmpl-' + '89D' + randomStr(26),
    object: 'chat.completion',
    created: moment().unix(),
    model,
    choices: [
      {
        index: 0,
        message: {
          role: 'assistant',
          ...ctx.body,
        },
        finish_reason: 'stop',
      },
    ],
    usage: {
      // the official API adds a fixed 7-token overhead to every request
      prompt_tokens: 7 + reqLen,
      completion_tokens,
      // keep the total consistent with the 7-token overhead above
      total_tokens: 7 + reqLen + completion_tokens,
    },
  };
};
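
// Claude-compatible /v1/complete: returns the reply in Anthropic's legacy
// text-completion envelope.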
const claudeHandle: Middleware = async (ctx, next) => {
  const { stream, model } = {
    ...(ctx.query as any),
    ...(ctx.request.body as any),
    ...(ctx.params as any),
  } as ClaudeReq;
  if (stream) {
    await AskStreamHandle(ClaudeEventStream)(ctx, next);
    return;
  }
  await AskHandle(ctx, next);
  ctx.body = {
    completion: ctx.body.content,
    stop_reason: 'stop_sequence',
    model: model,
    stop: '\n\nHuman:',
    log_id: randomStr(64).toLowerCase(),
  };
};
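
// Text-to-speech endpoint: delegates to the selected site's speech()
// implementation, which writes the audio response itself.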
const audioHandle: Middleware = async (ctx, next) => {
  const { site, ...req } = {
    ...(ctx.query as any),
    ...(ctx.request.body as any),
    ...(ctx.params as any),
  } as any;
  const chat = chatModel.get(site);
  if (!chat) {
    throw new ComError(`unsupported site: ${site}`, ComError.Status.NotFound);
  }
  await chat.speech(ctx, req);
};
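
// Image-generation endpoint: delegates to the selected site's generations()
// implementation.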
const imageGenHandle: Middleware = async (ctx, next) => {
  const { site, ...req } = {
    ...(ctx.query as any),
    ...(ctx.request.body as any),
    ...(ctx.params as any),
  } as any;
  const chat = chatModel.get(site);
  if (!chat) {
    throw new ComError(`unsupported site: ${site}`, ComError.Status.NotFound);
  }
  await chat.generations(ctx, req);
};
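
// Builds the Koa app, mounts middleware and routes, and starts listening on
// PORT (default 3000).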
export const registerApp = () => {
  const app = new Koa();
  app.use(cors());
  const router = new Router();
  app.use(async (ctx, next) => {
    ctx.logger = new TraceLogger();
    await next();
  });
  app.use(errorHandler);
  app.use(bodyParser({ jsonLimit: '10mb' }));
  app.use(checkApiKey);
  router.get('/supports', supportsHandler);
  router.get('/ask', AskHandle);
  router.post('/ask', AskHandle);
  router.get('/ask/stream', AskStreamHandle(EventStream));
  router.post('/ask/stream', AskStreamHandle(EventStream));
  router.post('/v1/chat/completions', openAIHandle);
  router.post('/:site/v1/chat/completions', openAIHandle);
  router.post('/v1/complete', claudeHandle);
  router.post('/:site/v1/complete', claudeHandle);
  router.post('/v1/audio/speech', audioHandle);
  router.post('/:site/v1/audio/speech', audioHandle);
  router.post('/:site/v1/images/generations', imageGenHandle);
  app.use(router.routes());
  const port = +(process.env.PORT || 3000);
  const server = app.listen(port, () => {
    console.log(`Now listening: 127.0.0.1:${port}`);
  });
  console.log(`Worker ${process.pid} started`);
};
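
// A minimal usage sketch (assumes a local instance on the default port and
// that API_KEY is set; the site/model values below are illustrative, not a
// definitive list):
//
//   curl http://127.0.0.1:3000/v1/chat/completions \
//     -H "Authorization: Bearer $API_KEY" \
//     -H "Content-Type: application/json" \
//     -d '{"site":"you","model":"gpt-3.5-turbo","stream":false,
//          "messages":[{"role":"user","content":"hello"}]}'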