diff --git a/lerna.json b/lerna.json index a57f61b29b..030dde848c 100644 --- a/lerna.json +++ b/lerna.json @@ -1,6 +1,6 @@ { "$schema": "node_modules/lerna/schemas/lerna-schema.json", - "version": "3.8.0", + "version": "3.8.1", "npmClient": "yarn", "concurrency": 20, "command": { diff --git a/packages/backend-core/src/cache/writethrough.ts b/packages/backend-core/src/cache/writethrough.ts index cd7409ca15..5a1d9f6a14 100644 --- a/packages/backend-core/src/cache/writethrough.ts +++ b/packages/backend-core/src/cache/writethrough.ts @@ -96,6 +96,24 @@ async function get(db: Database, id: string): Promise { return cacheItem.doc } +async function tryGet( + db: Database, + id: string +): Promise { + const cache = await getCache() + const cacheKey = makeCacheKey(db, id) + let cacheItem: CacheItem | null = await cache.get(cacheKey) + if (!cacheItem) { + const doc = await db.tryGet(id) + if (!doc) { + return null + } + cacheItem = makeCacheItem(doc) + await cache.store(cacheKey, cacheItem) + } + return cacheItem.doc +} + async function remove(db: Database, docOrId: any, rev?: any): Promise { const cache = await getCache() if (!docOrId) { @@ -123,10 +141,17 @@ export class Writethrough { return put(this.db, doc, writeRateMs) } + /** + * @deprecated use `tryGet` instead + */ async get(id: string) { return get(this.db, id) } + async tryGet(id: string) { + return tryGet(this.db, id) + } + async remove(docOrId: any, rev?: any) { return remove(this.db, docOrId, rev) } diff --git a/packages/backend-core/src/configs/configs.ts b/packages/backend-core/src/configs/configs.ts index f184bf87df..3747fff82e 100644 --- a/packages/backend-core/src/configs/configs.ts +++ b/packages/backend-core/src/configs/configs.ts @@ -47,6 +47,9 @@ export async function getConfig( export async function save( config: Config ): Promise<{ id: string; rev: string }> { + if (!config._id) { + config._id = generateConfigID(config.type) + } const db = context.getGlobalDB() return db.put(config) } diff --git a/packages/backend-core/src/configs/tests/configs.spec.ts b/packages/backend-core/src/configs/tests/configs.spec.ts index 2c6a1948ec..5b5186109c 100644 --- a/packages/backend-core/src/configs/tests/configs.spec.ts +++ b/packages/backend-core/src/configs/tests/configs.spec.ts @@ -12,7 +12,6 @@ describe("configs", () => { const setDbPlatformUrl = async (dbUrl: string) => { const settingsConfig = { - _id: configs.generateConfigID(ConfigType.SETTINGS), type: ConfigType.SETTINGS, config: { platformUrl: dbUrl, diff --git a/packages/backend-core/src/constants/db.ts b/packages/backend-core/src/constants/db.ts index 3085b91ef1..28d389e6ba 100644 --- a/packages/backend-core/src/constants/db.ts +++ b/packages/backend-core/src/constants/db.ts @@ -60,6 +60,11 @@ export const StaticDatabases = { SCIM_LOGS: { name: "scim-logs", }, + // Used by self-host users making use of Budicloud resources. Introduced when + // we started letting self-host users use Budibase AI in the cloud. + SELF_HOST_CLOUD: { + name: "self-host-cloud", + }, } export const APP_PREFIX = prefixed(DocumentType.APP) diff --git a/packages/backend-core/src/context/mainContext.ts b/packages/backend-core/src/context/mainContext.ts index 8e0c71ff18..e701f111aa 100644 --- a/packages/backend-core/src/context/mainContext.ts +++ b/packages/backend-core/src/context/mainContext.ts @@ -157,6 +157,33 @@ export async function doInTenant( return newContext(updates, task) } +// We allow self-host licensed users to make use of some Budicloud services +// (e.g. Budibase AI). When they do this, they use their license key as an API +// key. We use that license key to identify the tenant ID, and we set the +// context to be self-host using cloud. This affects things like where their +// quota documents get stored (because we want to avoid creating a new global +// DB for each self-host tenant). +export async function doInSelfHostTenantUsingCloud( + tenantId: string, + task: () => T +): Promise { + const updates = { tenantId, isSelfHostUsingCloud: true } + return newContext(updates, task) +} + +export function isSelfHostUsingCloud() { + const context = Context.get() + return !!context?.isSelfHostUsingCloud +} + +export function getSelfHostCloudDB() { + const context = Context.get() + if (!context || !context.isSelfHostUsingCloud) { + throw new Error("Self-host cloud DB not found") + } + return getDB(StaticDatabases.SELF_HOST_CLOUD.name) +} + export async function doInAppContext( appId: string, task: () => T @@ -325,6 +352,11 @@ export function getGlobalDB(): Database { if (!context || (env.MULTI_TENANCY && !context.tenantId)) { throw new Error("Global DB not found") } + if (context.isSelfHostUsingCloud) { + throw new Error( + "Global DB not found - self-host users using cloud don't have a global DB" + ) + } return getDB(baseGlobalDBName(context?.tenantId)) } @@ -344,6 +376,11 @@ export function getAppDB(opts?: any): Database { if (!appId) { throw new Error("Unable to retrieve app DB - no app ID.") } + if (isSelfHostUsingCloud()) { + throw new Error( + "App DB not found - self-host users using cloud don't have app DBs" + ) + } return getDB(appId, opts) } diff --git a/packages/backend-core/src/context/types.ts b/packages/backend-core/src/context/types.ts index 23598b951e..adee495e60 100644 --- a/packages/backend-core/src/context/types.ts +++ b/packages/backend-core/src/context/types.ts @@ -5,6 +5,7 @@ import { GoogleSpreadsheet } from "google-spreadsheet" // keep this out of Budibase types, don't want to expose context info export type ContextMap = { tenantId?: string + isSelfHostUsingCloud?: boolean appId?: string identity?: IdentityContext environmentVariables?: Record diff --git a/packages/pro b/packages/pro index 40c36f8658..8eb981cf01 160000 --- a/packages/pro +++ b/packages/pro @@ -1 +1 @@ -Subproject commit 40c36f86584568d31abd6dd5b6b00dd3a458093f +Subproject commit 8eb981cf01151261697a8f26c08c4c28f66b8e15 diff --git a/packages/server/package.json b/packages/server/package.json index e9bf4bbf15..18b13cba90 100644 --- a/packages/server/package.json +++ b/packages/server/package.json @@ -49,6 +49,7 @@ "author": "Budibase", "license": "GPL-3.0", "dependencies": { + "@anthropic-ai/sdk": "^0.27.3", "@apidevtools/swagger-parser": "10.0.3", "@aws-sdk/client-dynamodb": "3.709.0", "@aws-sdk/client-s3": "3.709.0", diff --git a/packages/server/scripts/dev/manage.js b/packages/server/scripts/dev/manage.js index a07fa1b582..a5af8650ef 100644 --- a/packages/server/scripts/dev/manage.js +++ b/packages/server/scripts/dev/manage.js @@ -47,6 +47,7 @@ async function init() { VERSION: "0.0.0+local", PASSWORD_MIN_LENGTH: "1", OPENAI_API_KEY: "sk-abcdefghijklmnopqrstuvwxyz1234567890abcd", + BUDICLOUD_URL: "https://budibaseqa.app", } config = { ...config, ...existingConfig } diff --git a/packages/server/src/api/routes/tests/ai.spec.ts b/packages/server/src/api/routes/tests/ai.spec.ts new file mode 100644 index 0000000000..ad2ae7dc50 --- /dev/null +++ b/packages/server/src/api/routes/tests/ai.spec.ts @@ -0,0 +1,344 @@ +import { mockChatGPTResponse } from "../../../tests/utilities/mocks/ai/openai" +import TestConfiguration from "../../../tests/utilities/TestConfiguration" +import nock from "nock" +import { configs, env, features, setEnv } from "@budibase/backend-core" +import { + AIInnerConfig, + ConfigType, + License, + PlanModel, + PlanType, + ProviderConfig, +} from "@budibase/types" +import { context } from "@budibase/backend-core" +import { mocks } from "@budibase/backend-core/tests" +import { MockLLMResponseFn } from "../../../tests/utilities/mocks/ai" +import { mockAnthropicResponse } from "../../../tests/utilities/mocks/ai/anthropic" + +function dedent(str: string) { + return str + .split("\n") + .map(line => line.trim()) + .join("\n") +} + +type SetupFn = ( + config: TestConfiguration +) => Promise<() => Promise | void> +interface TestSetup { + name: string + setup: SetupFn + mockLLMResponse: MockLLMResponseFn +} + +function budibaseAI(): SetupFn { + return async () => { + const cleanup = setEnv({ + OPENAI_API_KEY: "test-key", + }) + mocks.licenses.useBudibaseAI() + return async () => { + mocks.licenses.useCloudFree() + cleanup() + } + } +} + +function customAIConfig(providerConfig: Partial): SetupFn { + return async (config: TestConfiguration) => { + mocks.licenses.useAICustomConfigs() + + const innerConfig: AIInnerConfig = { + myaiconfig: { + provider: "OpenAI", + name: "OpenAI", + apiKey: "test-key", + defaultModel: "gpt-4o-mini", + active: true, + isDefault: true, + ...providerConfig, + }, + } + + const { id, rev } = await config.doInTenant( + async () => + await configs.save({ + type: ConfigType.AI, + config: innerConfig, + }) + ) + + return async () => { + mocks.licenses.useCloudFree() + + await config.doInTenant(async () => { + const db = context.getGlobalDB() + await db.remove(id, rev) + }) + } + } +} + +const allProviders: TestSetup[] = [ + { + name: "OpenAI API key", + setup: async () => { + return setEnv({ + OPENAI_API_KEY: "test-key", + }) + }, + mockLLMResponse: mockChatGPTResponse, + }, + { + name: "OpenAI API key with custom config", + setup: customAIConfig({ provider: "OpenAI", defaultModel: "gpt-4o-mini" }), + mockLLMResponse: mockChatGPTResponse, + }, + { + name: "Anthropic API key with custom config", + setup: customAIConfig({ + provider: "Anthropic", + defaultModel: "claude-3-5-sonnet-20240620", + }), + mockLLMResponse: mockAnthropicResponse, + }, + { + name: "BudibaseAI", + setup: budibaseAI(), + mockLLMResponse: mockChatGPTResponse, + }, +] + +describe("AI", () => { + const config = new TestConfiguration() + + beforeAll(async () => { + await config.init() + }) + + afterAll(() => { + config.end() + }) + + beforeEach(() => { + nock.cleanAll() + }) + + describe.each(allProviders)( + "provider: $name", + ({ setup, mockLLMResponse }: TestSetup) => { + let cleanup: () => Promise | void + beforeAll(async () => { + cleanup = await setup(config) + }) + + afterAll(async () => { + const maybePromise = cleanup() + if (maybePromise) { + await maybePromise + } + }) + + describe("POST /api/ai/js", () => { + let cleanup: () => void + beforeAll(() => { + cleanup = features.testutils.setFeatureFlags("*", { + AI_JS_GENERATION: true, + }) + }) + + afterAll(() => { + cleanup() + }) + + it("handles correct plain code response", async () => { + mockLLMResponse(`return 42`) + + const { code } = await config.api.ai.generateJs({ prompt: "test" }) + expect(code).toBe("return 42") + }) + + it("handles correct markdown code response", async () => { + mockLLMResponse( + dedent(` + \`\`\`js + return 42 + \`\`\` + `) + ) + + const { code } = await config.api.ai.generateJs({ prompt: "test" }) + expect(code).toBe("return 42") + }) + + it("handles multiple markdown code blocks returned", async () => { + mockLLMResponse( + dedent(` + This: + + \`\`\`js + return 42 + \`\`\` + + Or this: + + \`\`\`js + return 10 + \`\`\` + `) + ) + + const { code } = await config.api.ai.generateJs({ prompt: "test" }) + expect(code).toBe("return 42") + }) + + // TODO: handle when this happens + it.skip("handles no code response", async () => { + mockLLMResponse("I'm sorry, you're quite right, etc.") + const { code } = await config.api.ai.generateJs({ prompt: "test" }) + expect(code).toBe("") + }) + + it("handles LLM errors", async () => { + mockLLMResponse(() => { + throw new Error("LLM error") + }) + await config.api.ai.generateJs({ prompt: "test" }, { status: 500 }) + }) + }) + + describe("POST /api/ai/cron", () => { + it("handles correct cron response", async () => { + mockLLMResponse("0 0 * * *") + + const { message } = await config.api.ai.generateCron({ + prompt: "test", + }) + expect(message).toBe("0 0 * * *") + }) + + it("handles expected LLM error", async () => { + mockLLMResponse("Error generating cron: skill issue") + + await config.api.ai.generateCron( + { + prompt: "test", + }, + { status: 400 } + ) + }) + + it("handles unexpected LLM error", async () => { + mockLLMResponse(() => { + throw new Error("LLM error") + }) + + await config.api.ai.generateCron( + { + prompt: "test", + }, + { status: 500 } + ) + }) + }) + } + ) +}) + +describe("BudibaseAI", () => { + const config = new TestConfiguration() + let cleanup: () => void | Promise + beforeAll(async () => { + await config.init() + cleanup = await budibaseAI()(config) + }) + + afterAll(async () => { + if ("then" in cleanup) { + await cleanup() + } else { + cleanup() + } + config.end() + }) + + describe("POST /api/ai/chat", () => { + let envCleanup: () => void + let featureCleanup: () => void + beforeAll(() => { + envCleanup = setEnv({ SELF_HOSTED: false }) + featureCleanup = features.testutils.setFeatureFlags("*", { + AI_JS_GENERATION: true, + }) + }) + + afterAll(() => { + featureCleanup() + envCleanup() + }) + + beforeEach(() => { + nock.cleanAll() + const license: License = { + plan: { + type: PlanType.FREE, + model: PlanModel.PER_USER, + usesInvoicing: false, + }, + features: [], + quotas: {} as any, + tenantId: config.tenantId, + } + nock(env.ACCOUNT_PORTAL_URL).get("/api/license").reply(200, license) + }) + + it("handles correct chat response", async () => { + mockChatGPTResponse("Hi there!") + const { message } = await config.api.ai.chat({ + messages: [{ role: "user", content: "Hello!" }], + licenseKey: "test-key", + }) + expect(message).toBe("Hi there!") + }) + + it("handles chat response error", async () => { + mockChatGPTResponse(() => { + throw new Error("LLM error") + }) + await config.api.ai.chat( + { + messages: [{ role: "user", content: "Hello!" }], + licenseKey: "test-key", + }, + { status: 500 } + ) + }) + + it("handles no license", async () => { + nock.cleanAll() + nock(env.ACCOUNT_PORTAL_URL).get("/api/license").reply(404) + await config.api.ai.chat( + { + messages: [{ role: "user", content: "Hello!" }], + licenseKey: "test-key", + }, + { + status: 403, + } + ) + }) + + it("handles no license key", async () => { + await config.api.ai.chat( + { + messages: [{ role: "user", content: "Hello!" }], + // @ts-expect-error - intentionally wrong + licenseKey: undefined, + }, + { + status: 403, + } + ) + }) + }) +}) diff --git a/packages/server/src/api/routes/tests/row.spec.ts b/packages/server/src/api/routes/tests/row.spec.ts index c55db8640c..3fb882ff2f 100644 --- a/packages/server/src/api/routes/tests/row.spec.ts +++ b/packages/server/src/api/routes/tests/row.spec.ts @@ -46,7 +46,7 @@ import { withEnv } from "../../../environment" import { JsTimeoutError } from "@budibase/string-templates" import { isDate } from "../../../utilities" import nock from "nock" -import { mockChatGPTResponse } from "../../../tests/utilities/mocks/openai" +import { mockChatGPTResponse } from "../../../tests/utilities/mocks/ai/openai" const timestamp = new Date("2023-01-26T11:48:57.597Z").toISOString() tk.freeze(timestamp) diff --git a/packages/server/src/api/routes/tests/search.spec.ts b/packages/server/src/api/routes/tests/search.spec.ts index e115297ee9..7a7f388a2c 100644 --- a/packages/server/src/api/routes/tests/search.spec.ts +++ b/packages/server/src/api/routes/tests/search.spec.ts @@ -44,7 +44,7 @@ import { generator, structures, mocks } from "@budibase/backend-core/tests" import { DEFAULT_EMPLOYEE_TABLE_SCHEMA } from "../../../db/defaultData/datasource_bb_default" import { generateRowIdField } from "../../../integrations/utils" import { cloneDeep } from "lodash/fp" -import { mockChatGPTResponse } from "../../../tests/utilities/mocks/openai" +import { mockChatGPTResponse } from "../../../tests/utilities/mocks/ai/openai" const descriptions = datasourceDescribe({ plus: true }) diff --git a/packages/server/src/api/routes/tests/viewV2.spec.ts b/packages/server/src/api/routes/tests/viewV2.spec.ts index ad41aa618c..bca7d16807 100644 --- a/packages/server/src/api/routes/tests/viewV2.spec.ts +++ b/packages/server/src/api/routes/tests/viewV2.spec.ts @@ -41,7 +41,7 @@ import { datasourceDescribe } from "../../../integrations/tests/utils" import merge from "lodash/merge" import { quotas } from "@budibase/pro" import { context, db, events, roles, setEnv } from "@budibase/backend-core" -import { mockChatGPTResponse } from "../../../tests/utilities/mocks/openai" +import { mockChatGPTResponse } from "../../../tests/utilities/mocks/ai/openai" import nock from "nock" const descriptions = datasourceDescribe({ plus: true }) diff --git a/packages/server/src/automations/tests/steps/openai.spec.ts b/packages/server/src/automations/tests/steps/openai.spec.ts index a06c633e5e..3ad03eb1b2 100644 --- a/packages/server/src/automations/tests/steps/openai.spec.ts +++ b/packages/server/src/automations/tests/steps/openai.spec.ts @@ -1,11 +1,8 @@ import { createAutomationBuilder } from "../utilities/AutomationTestBuilder" -import { setEnv as setCoreEnv } from "@budibase/backend-core" +import { setEnv as setCoreEnv, withEnv } from "@budibase/backend-core" import { Model, MonthlyQuotaName, QuotaUsageType } from "@budibase/types" import TestConfiguration from "../../..//tests/utilities/TestConfiguration" -import { - mockChatGPTError, - mockChatGPTResponse, -} from "../../../tests/utilities/mocks/openai" +import { mockChatGPTResponse } from "../../../tests/utilities/mocks/ai/openai" import nock from "nock" import { mocks } from "@budibase/backend-core/tests" import { quotas } from "@budibase/pro" @@ -83,7 +80,9 @@ describe("test the openai action", () => { }) it("should present the correct error message when an error is thrown from the createChatCompletion call", async () => { - mockChatGPTError() + mockChatGPTResponse(() => { + throw new Error("oh no") + }) const result = await expectAIUsage(0, () => createAutomationBuilder(config) @@ -108,11 +107,13 @@ describe("test the openai action", () => { // path, because we've enabled Budibase AI. The exact value depends on a // calculation we use to approximate cost. This uses Budibase's OpenAI API // key, so we charge users for it. - const result = await expectAIUsage(14, () => - createAutomationBuilder(config) - .onAppAction() - .openai({ model: Model.GPT_4O_MINI, prompt: "Hello, world" }) - .test({ fields: {} }) + const result = await withEnv({ SELF_HOSTED: false }, () => + expectAIUsage(14, () => + createAutomationBuilder(config) + .onAppAction() + .openai({ model: Model.GPT_4O_MINI, prompt: "Hello, world" }) + .test({ fields: {} }) + ) ) expect(result.steps[0].outputs.response).toEqual("This is a test") diff --git a/packages/server/src/constants/screens.ts b/packages/server/src/constants/screens.ts index 41c1e74874..3a7413633d 100644 --- a/packages/server/src/constants/screens.ts +++ b/packages/server/src/constants/screens.ts @@ -365,7 +365,11 @@ export function createSampleDataTableScreen(): Screen { _component: "@budibase/standard-components/textv2", _styles: { normal: { + "--grid-desktop-col-start": 1, "--grid-desktop-col-end": 3, + "--grid-desktop-row-start": 1, + "--grid-desktop-row-end": 3, + "--grid-mobile-col-end": 7, }, hover: {}, active: {}, @@ -384,6 +388,7 @@ export function createSampleDataTableScreen(): Screen { "--grid-desktop-row-start": 1, "--grid-desktop-row-end": 3, "--grid-desktop-h-align": "end", + "--grid-mobile-col-start": 7, }, hover: {}, active: {}, diff --git a/packages/server/src/tests/utilities/api/ai.ts b/packages/server/src/tests/utilities/api/ai.ts new file mode 100644 index 0000000000..efaa321f09 --- /dev/null +++ b/packages/server/src/tests/utilities/api/ai.ts @@ -0,0 +1,47 @@ +import { + ChatCompletionRequest, + ChatCompletionResponse, + GenerateCronRequest, + GenerateCronResponse, + GenerateJsRequest, + GenerateJsResponse, +} from "@budibase/types" +import { Expectations, TestAPI } from "./base" +import { constants } from "@budibase/backend-core" + +export class AIAPI extends TestAPI { + generateJs = async ( + req: GenerateJsRequest, + expectations?: Expectations + ): Promise => { + return await this._post(`/api/ai/js`, { + body: req, + expectations, + }) + } + + generateCron = async ( + req: GenerateCronRequest, + expectations?: Expectations + ): Promise => { + return await this._post(`/api/ai/cron`, { + body: req, + expectations, + }) + } + + chat = async ( + req: ChatCompletionRequest & { licenseKey: string }, + expectations?: Expectations + ): Promise => { + const headers: Record = {} + if (req.licenseKey) { + headers[constants.Header.LICENSE_KEY] = req.licenseKey + } + return await this._post(`/api/ai/chat`, { + body: req, + headers, + expectations, + }) + } +} diff --git a/packages/server/src/tests/utilities/api/index.ts b/packages/server/src/tests/utilities/api/index.ts index ba99c2eca0..9c00b77b73 100644 --- a/packages/server/src/tests/utilities/api/index.ts +++ b/packages/server/src/tests/utilities/api/index.ts @@ -22,8 +22,10 @@ import { UserPublicAPI } from "./public/user" import { MiscAPI } from "./misc" import { OAuth2API } from "./oauth2" import { AssetsAPI } from "./assets" +import { AIAPI } from "./ai" export default class API { + ai: AIAPI application: ApplicationAPI attachment: AttachmentAPI automation: AutomationAPI @@ -52,6 +54,7 @@ export default class API { } constructor(config: TestConfiguration) { + this.ai = new AIAPI(config) this.application = new ApplicationAPI(config) this.attachment = new AttachmentAPI(config) this.automation = new AutomationAPI(config) diff --git a/packages/server/src/tests/utilities/mocks/ai/anthropic.ts b/packages/server/src/tests/utilities/mocks/ai/anthropic.ts new file mode 100644 index 0000000000..ff0413aee1 --- /dev/null +++ b/packages/server/src/tests/utilities/mocks/ai/anthropic.ts @@ -0,0 +1,48 @@ +import AnthropicClient from "@anthropic-ai/sdk" +import nock from "nock" +import { MockLLMResponseFn, MockLLMResponseOpts } from "." + +let chatID = 1 +const SPACE_REGEX = /\s+/g + +export const mockAnthropicResponse: MockLLMResponseFn = ( + answer: string | ((prompt: string) => string), + opts?: MockLLMResponseOpts +) => { + return nock(opts?.host || "https://api.anthropic.com") + .post("/v1/messages") + .reply((uri: string, body: nock.Body) => { + const req = body as AnthropicClient.MessageCreateParamsNonStreaming + const prompt = req.messages[0].content + if (typeof prompt !== "string") { + throw new Error("Anthropic mock only supports string prompts") + } + + let content + if (typeof answer === "function") { + try { + content = answer(prompt) + } catch (e) { + return [500, "Internal Server Error"] + } + } else { + content = answer + } + + const resp: AnthropicClient.Messages.Message = { + id: `${chatID++}`, + type: "message", + role: "assistant", + model: req.model, + stop_reason: "end_turn", + usage: { + input_tokens: prompt.split(SPACE_REGEX).length, + output_tokens: content.split(SPACE_REGEX).length, + }, + stop_sequence: null, + content: [{ type: "text", text: content }], + } + return [200, resp] + }) + .persist() +} diff --git a/packages/server/src/tests/utilities/mocks/ai/index.ts b/packages/server/src/tests/utilities/mocks/ai/index.ts new file mode 100644 index 0000000000..87f8ce77be --- /dev/null +++ b/packages/server/src/tests/utilities/mocks/ai/index.ts @@ -0,0 +1,10 @@ +import { Scope } from "nock" + +export interface MockLLMResponseOpts { + host?: string +} + +export type MockLLMResponseFn = ( + answer: string | ((prompt: string) => string), + opts?: MockLLMResponseOpts +) => Scope diff --git a/packages/server/src/tests/utilities/mocks/openai.ts b/packages/server/src/tests/utilities/mocks/ai/openai.ts similarity index 81% rename from packages/server/src/tests/utilities/mocks/openai.ts rename to packages/server/src/tests/utilities/mocks/ai/openai.ts index 7fcc0c08fc..827caad9be 100644 --- a/packages/server/src/tests/utilities/mocks/openai.ts +++ b/packages/server/src/tests/utilities/mocks/ai/openai.ts @@ -1,12 +1,9 @@ import nock from "nock" +import { MockLLMResponseFn, MockLLMResponseOpts } from "." let chatID = 1 const SPACE_REGEX = /\s+/g -interface MockChatGPTResponseOpts { - host?: string -} - interface Message { role: string content: string @@ -47,19 +44,24 @@ interface ChatCompletionResponse { usage: Usage } -export function mockChatGPTResponse( +export const mockChatGPTResponse: MockLLMResponseFn = ( answer: string | ((prompt: string) => string), - opts?: MockChatGPTResponseOpts -) { + opts?: MockLLMResponseOpts +) => { return nock(opts?.host || "https://api.openai.com") .post("/v1/chat/completions") - .reply(200, (uri: string, requestBody: ChatCompletionRequest) => { - const messages = requestBody.messages + .reply((uri: string, body: nock.Body) => { + const req = body as ChatCompletionRequest + const messages = req.messages const prompt = messages[0].content let content if (typeof answer === "function") { - content = answer(prompt) + try { + content = answer(prompt) + } catch (e) { + return [500, "Internal Server Error"] + } } else { content = answer } @@ -76,7 +78,7 @@ export function mockChatGPTResponse( id: `chatcmpl-${chatID}`, object: "chat.completion", created: Math.floor(Date.now() / 1000), - model: requestBody.model, + model: req.model, system_fingerprint: `fp_${chatID}`, choices: [ { @@ -97,14 +99,7 @@ export function mockChatGPTResponse( }, }, } - return response + return [200, response] }) .persist() } - -export function mockChatGPTError() { - return nock("https://api.openai.com") - .post("/v1/chat/completions") - .reply(500, "Internal Server Error") - .persist() -} diff --git a/packages/types/src/api/web/ai.ts b/packages/types/src/api/web/ai.ts index 3962422b77..f9c587ca0b 100644 --- a/packages/types/src/api/web/ai.ts +++ b/packages/types/src/api/web/ai.ts @@ -1,5 +1,18 @@ import { EnrichedBinding } from "../../ui" +export interface Message { + role: "system" | "user" + content: string +} + +export interface ChatCompletionRequest { + messages: Message[] +} + +export interface ChatCompletionResponse { + message?: string +} + export interface GenerateJsRequest { prompt: string bindings?: EnrichedBinding[] @@ -8,3 +21,11 @@ export interface GenerateJsRequest { export interface GenerateJsResponse { code: string } + +export interface GenerateCronRequest { + prompt: string +} + +export interface GenerateCronResponse { + message?: string +} diff --git a/packages/types/src/documents/global/config.ts b/packages/types/src/documents/global/config.ts index bd0340595c..422486e30f 100644 --- a/packages/types/src/documents/global/config.ts +++ b/packages/types/src/documents/global/config.ts @@ -117,6 +117,7 @@ export type AIProvider = | "AzureOpenAI" | "TogetherAI" | "Custom" + | "BudibaseAI" export interface ProviderConfig { provider: AIProvider