diff --git a/packages/server/src/api/routes/tests/ai.spec.ts b/packages/server/src/api/routes/tests/ai.spec.ts index 9e041a619e..0afd0e274f 100644 --- a/packages/server/src/api/routes/tests/ai.spec.ts +++ b/packages/server/src/api/routes/tests/ai.spec.ts @@ -1,3 +1,4 @@ +import { z } from "zod" import { mockChatGPTResponse } from "../../../tests/utilities/mocks/ai/openai" import TestConfiguration from "../../../tests/utilities/TestConfiguration" import nock from "nock" @@ -10,12 +11,13 @@ import { PlanModel, PlanType, ProviderConfig, + StructuredOutput, } from "@budibase/types" import { context } from "@budibase/backend-core" -import { mocks } from "@budibase/backend-core/tests" +import { generator, mocks } from "@budibase/backend-core/tests" +import { ai, quotas } from "@budibase/pro" import { MockLLMResponseFn } from "../../../tests/utilities/mocks/ai" import { mockAnthropicResponse } from "../../../tests/utilities/mocks/ai/anthropic" -import { quotas } from "@budibase/pro" function dedent(str: string) { return str @@ -285,7 +287,8 @@ describe("BudibaseAI", () => { envCleanup() }) - beforeEach(() => { + beforeEach(async () => { + await config.newTenant() nock.cleanAll() const license: License = { plan: { @@ -366,5 +369,66 @@ describe("BudibaseAI", () => { } ) }) + + it("handles text format", async () => { + let usage = await getQuotaUsage() + expect(usage._id).toBe(`quota_usage_${config.getTenantId()}`) + expect(usage.monthly.current.budibaseAICredits).toBe(0) + + const gptResponse = generator.word() + mockChatGPTResponse(gptResponse, { format: "text" }) + const { message } = await config.api.ai.chat({ + messages: [{ role: "user", content: "Hello!" }], + format: "text", + licenseKey: licenseKey, + }) + expect(message).toBe(gptResponse) + + usage = await getQuotaUsage() + expect(usage.monthly.current.budibaseAICredits).toBeGreaterThan(0) + }) + + it("handles json format", async () => { + let usage = await getQuotaUsage() + expect(usage._id).toBe(`quota_usage_${config.getTenantId()}`) + expect(usage.monthly.current.budibaseAICredits).toBe(0) + + const gptResponse = JSON.stringify({ + [generator.word()]: generator.word(), + }) + mockChatGPTResponse(gptResponse, { format: "json" }) + const { message } = await config.api.ai.chat({ + messages: [{ role: "user", content: "Hello!" }], + format: "json", + licenseKey: licenseKey, + }) + expect(message).toBe(gptResponse) + + usage = await getQuotaUsage() + expect(usage.monthly.current.budibaseAICredits).toBeGreaterThan(0) + }) + + it("handles structured outputs", async () => { + let usage = await getQuotaUsage() + expect(usage._id).toBe(`quota_usage_${config.getTenantId()}`) + expect(usage.monthly.current.budibaseAICredits).toBe(0) + + const gptResponse = generator.guid() + const structuredOutput = generator.word() as unknown as StructuredOutput + ai.structuredOutputs[structuredOutput] = { + key: generator.word(), + validator: z.object({ name: z.string() }), + } + mockChatGPTResponse(gptResponse, { format: structuredOutput }) + const { message } = await config.api.ai.chat({ + messages: [{ role: "user", content: "Hello!" }], + format: structuredOutput, + licenseKey: licenseKey, + }) + expect(message).toBe(gptResponse) + + usage = await getQuotaUsage() + expect(usage.monthly.current.budibaseAICredits).toBeGreaterThan(0) + }) }) }) diff --git a/packages/server/src/tests/utilities/mocks/ai/index.ts b/packages/server/src/tests/utilities/mocks/ai/index.ts index 87f8ce77be..d7df6be44f 100644 --- a/packages/server/src/tests/utilities/mocks/ai/index.ts +++ b/packages/server/src/tests/utilities/mocks/ai/index.ts @@ -1,7 +1,9 @@ +import { ResponseFormat } from "@budibase/types" import { Scope } from "nock" export interface MockLLMResponseOpts { host?: string + format?: ResponseFormat } export type MockLLMResponseFn = ( diff --git a/packages/server/src/tests/utilities/mocks/ai/openai.ts b/packages/server/src/tests/utilities/mocks/ai/openai.ts index 827caad9be..3a9ac7f87a 100644 --- a/packages/server/src/tests/utilities/mocks/ai/openai.ts +++ b/packages/server/src/tests/utilities/mocks/ai/openai.ts @@ -1,5 +1,7 @@ import nock from "nock" import { MockLLMResponseFn, MockLLMResponseOpts } from "." +import _ from "lodash" +import { ai } from "@budibase/pro" let chatID = 1 const SPACE_REGEX = /\s+/g @@ -48,8 +50,15 @@ export const mockChatGPTResponse: MockLLMResponseFn = ( answer: string | ((prompt: string) => string), opts?: MockLLMResponseOpts ) => { + let body: any = undefined + + if (opts?.format) { + body = _.matches({ + response_format: ai.openai.parseResponseFormat(opts.format), + }) + } return nock(opts?.host || "https://api.openai.com") - .post("/v1/chat/completions") + .post("/v1/chat/completions", body) .reply((uri: string, body: nock.Body) => { const req = body as ChatCompletionRequest const messages = req.messages diff --git a/packages/types/src/api/web/ai.ts b/packages/types/src/api/web/ai.ts index 3ae1b08f81..b84b648591 100644 --- a/packages/types/src/api/web/ai.ts +++ b/packages/types/src/api/web/ai.ts @@ -5,8 +5,13 @@ export interface Message { content: string } +export enum StructuredOutput {} + +export type ResponseFormat = "text" | "json" | StructuredOutput + export interface ChatCompletionRequest { messages: Message[] + format?: ResponseFormat } export interface ChatCompletionResponse {