From b9fb4416bb428acab70288fdc7966454670b738f Mon Sep 17 00:00:00 2001 From: Sam Rose Date: Wed, 8 Jan 2025 13:18:40 +0000 Subject: [PATCH] Allow use of AI fields in view calculations. --- packages/pro | 2 +- .../src/api/controllers/row/staticFormula.ts | 6 ++ .../server/src/api/routes/tests/row.spec.ts | 40 ++++---- .../src/api/routes/tests/viewV2.spec.ts | 97 ++++++++++++++++++- .../src/tests/utilities/mocks/openai.ts | 46 +++++++++ .../src/utilities/rowProcessor/utils.ts | 2 +- packages/types/src/documents/app/row.ts | 1 + .../types/src/documents/app/table/schema.ts | 2 +- 8 files changed, 170 insertions(+), 26 deletions(-) create mode 100644 packages/server/src/tests/utilities/mocks/openai.ts diff --git a/packages/pro b/packages/pro index 32d84f109d..45f5b6fe9b 160000 --- a/packages/pro +++ b/packages/pro @@ -1 +1 @@ -Subproject commit 32d84f109d4edc526145472a7446327312151442 +Subproject commit 45f5b6fe9bbdbdf502581740ab43b82e8153260f diff --git a/packages/server/src/api/controllers/row/staticFormula.ts b/packages/server/src/api/controllers/row/staticFormula.ts index b81a164807..afa3a1f239 100644 --- a/packages/server/src/api/controllers/row/staticFormula.ts +++ b/packages/server/src/api/controllers/row/staticFormula.ts @@ -162,6 +162,12 @@ export async function finaliseRow( dynamic: false, contextRows: [enrichedRow], }) + + const flag1 = await features.isEnabled(FeatureFlag.BUDIBASE_AI) + const flag2 = await pro.features.isBudibaseAIEnabled() + const flag3 = await features.isEnabled(FeatureFlag.AI_CUSTOM_CONFIGS) + const flag4 = await pro.features.isAICustomConfigsEnabled() + const aiEnabled = ((await features.isEnabled(FeatureFlag.BUDIBASE_AI)) && (await pro.features.isBudibaseAIEnabled())) || diff --git a/packages/server/src/api/routes/tests/row.spec.ts b/packages/server/src/api/routes/tests/row.spec.ts index a3012c3760..968ce9c798 100644 --- a/packages/server/src/api/routes/tests/row.spec.ts +++ b/packages/server/src/api/routes/tests/row.spec.ts @@ -8,7 +8,13 @@ import { import tk from "timekeeper" import emitter from "../../../../src/events" import { outputProcessing } from "../../../utilities/rowProcessor" -import { context, InternalTable, tenancy, utils } from "@budibase/backend-core" +import { + context, + setEnv, + InternalTable, + tenancy, + utils, +} from "@budibase/backend-core" import { quotas } from "@budibase/pro" import { AIOperationEnum, @@ -42,19 +48,8 @@ import { InternalTables } from "../../../db/utils" import { withEnv } from "../../../environment" import { JsTimeoutError } from "@budibase/string-templates" import { isDate } from "../../../utilities" - -jest.mock("@budibase/pro", () => ({ - ...jest.requireActual("@budibase/pro"), - ai: { - LargeLanguageModel: { - forCurrentTenant: async () => ({ - llm: {}, - run: jest.fn(() => `Mock LLM Response`), - buildPromptFromAIOperation: jest.fn(), - }), - }, - }, -})) +import nock from "nock" +import { mockChatGPTResponse } from "../../../tests/utilities/mocks/openai" const timestamp = new Date("2023-01-26T11:48:57.597Z").toISOString() tk.freeze(timestamp) @@ -99,6 +94,8 @@ if (descriptions.length) { const ds = await dsProvider() datasource = ds.datasource client = ds.client + + mocks.licenses.useCloudFree() }) afterAll(async () => { @@ -172,10 +169,6 @@ if (descriptions.length) { ) } - beforeEach(async () => { - mocks.licenses.useCloudFree() - }) - const getRowUsage = async () => { const { total } = await config.doInContext(undefined, () => quotas.getCurrentUsageValues( @@ -3224,10 +3217,17 @@ if (descriptions.length) { isInternal && describe("AI fields", () => { let table: Table + let envCleanup: () => void beforeAll(async () => { mocks.licenses.useBudibaseAI() mocks.licenses.useAICustomConfigs() + envCleanup = setEnv({ + OPENAI_API_KEY: "sk-abcdefghijklmnopqrstuvwxyz1234567890abcd", + }) + + mockChatGPTResponse("Mock LLM Response") + table = await config.api.table.save( saveTableRequest({ schema: { @@ -3251,7 +3251,9 @@ if (descriptions.length) { }) afterAll(() => { - jest.unmock("@budibase/pro") + nock.cleanAll() + envCleanup() + mocks.licenses.useCloudFree() }) it("should be able to save a row with an AI column", async () => { diff --git a/packages/server/src/api/routes/tests/viewV2.spec.ts b/packages/server/src/api/routes/tests/viewV2.spec.ts index 6ace7e256b..57efc868e9 100644 --- a/packages/server/src/api/routes/tests/viewV2.spec.ts +++ b/packages/server/src/api/routes/tests/viewV2.spec.ts @@ -1,4 +1,5 @@ import { + AIOperationEnum, ArrayOperator, BasicOperator, BBReferenceFieldSubType, @@ -42,7 +43,9 @@ import { } from "../../../integrations/tests/utils" import merge from "lodash/merge" import { quotas } from "@budibase/pro" -import { context, db, events, roles } from "@budibase/backend-core" +import { context, db, events, roles, setEnv } from "@budibase/backend-core" +import { mockChatGPTResponse } from "../../../tests/utilities/mocks/openai" +import nock from "nock" const descriptions = datasourceDescribe({ exclude: [DatabaseName.MONGODB] }) @@ -100,6 +103,7 @@ if (descriptions.length) { beforeAll(async () => { await config.init() + mocks.licenses.useCloudFree() const ds = await dsProvider() rawDatasource = ds.rawDatasource @@ -109,7 +113,6 @@ if (descriptions.length) { beforeEach(() => { jest.clearAllMocks() - mocks.licenses.useCloudFree() }) describe("view crud", () => { @@ -507,7 +510,6 @@ if (descriptions.length) { }) it("readonly fields can be used on free license", async () => { - mocks.licenses.useCloudFree() const table = await config.api.table.save( saveTableRequest({ schema: { @@ -933,6 +935,94 @@ if (descriptions.length) { } ) }) + + describe("AI fields", () => { + let envCleanup: () => void + beforeAll(() => { + mocks.licenses.useBudibaseAI() + mocks.licenses.useAICustomConfigs() + envCleanup = setEnv({ + OPENAI_API_KEY: "sk-abcdefghijklmnopqrstuvwxyz1234567890abcd", + }) + + mockChatGPTResponse(prompt => { + if (prompt.includes("elephant")) { + return "big" + } + if (prompt.includes("mouse")) { + return "small" + } + if (prompt.includes("whale")) { + return "big" + } + return "unknown" + }) + }) + + afterAll(() => { + nock.cleanAll() + envCleanup() + mocks.licenses.useCloudFree() + }) + + it("can use AI fields in view calculations", async () => { + const table = await config.api.table.save( + saveTableRequest({ + schema: { + animal: { + name: "animal", + type: FieldType.STRING, + }, + bigOrSmall: { + name: "bigOrSmall", + type: FieldType.AI, + operation: AIOperationEnum.CATEGORISE_TEXT, + categories: "big,small", + columns: ["animal"], + }, + }, + }) + ) + + const view = await config.api.viewV2.create({ + tableId: table._id!, + name: generator.guid(), + type: ViewV2Type.CALCULATION, + schema: { + bigOrSmall: { + visible: true, + }, + count: { + visible: true, + calculationType: CalculationType.COUNT, + field: "animal", + }, + }, + }) + + await config.api.row.save(table._id!, { + animal: "elephant", + }) + + await config.api.row.save(table._id!, { + animal: "mouse", + }) + + await config.api.row.save(table._id!, { + animal: "whale", + }) + + const { rows } = await config.api.row.search(view.id, { + sort: "bigOrSmall", + sortOrder: SortOrder.ASCENDING, + }) + expect(rows).toHaveLength(2) + expect(rows[0].bigOrSmall).toEqual("big") + expect(rows[1].bigOrSmall).toEqual("small") + expect(rows[0].count).toEqual(2) + expect(rows[1].count).toEqual(1) + }) + }) }) describe("update", () => { @@ -1836,7 +1926,6 @@ if (descriptions.length) { }, }) - mocks.licenses.useCloudFree() const view = await getDelegate(res) expect(view.schema?.one).toEqual( expect.objectContaining({ visible: true, readonly: true }) diff --git a/packages/server/src/tests/utilities/mocks/openai.ts b/packages/server/src/tests/utilities/mocks/openai.ts new file mode 100644 index 0000000000..b17491808c --- /dev/null +++ b/packages/server/src/tests/utilities/mocks/openai.ts @@ -0,0 +1,46 @@ +import nock from "nock" + +let chatID = 1 + +export function mockChatGPTResponse( + response: string | ((prompt: string) => string) +) { + return nock("https://api.openai.com") + .post("/v1/chat/completions") + .reply(200, (uri, requestBody) => { + let content = response + if (typeof response === "function") { + const messages = (requestBody as any).messages + content = response(messages[0].content) + } + + chatID++ + + return { + id: `chatcmpl-${chatID}`, + object: "chat.completion", + created: Math.floor(Date.now() / 1000), + model: "gpt-4o-mini", + system_fingerprint: `fp_${chatID}`, + choices: [ + { + index: 0, + message: { role: "assistant", content }, + logprobs: null, + finish_reason: "stop", + }, + ], + usage: { + prompt_tokens: 0, + completion_tokens: 0, + total_tokens: 0, + completion_tokens_details: { + reasoning_tokens: 0, + accepted_prediction_tokens: 0, + rejected_prediction_tokens: 0, + }, + }, + } + }) + .persist() +} diff --git a/packages/server/src/utilities/rowProcessor/utils.ts b/packages/server/src/utilities/rowProcessor/utils.ts index 09d3324ded..7d2f8b49f4 100644 --- a/packages/server/src/utilities/rowProcessor/utils.ts +++ b/packages/server/src/utilities/rowProcessor/utils.ts @@ -160,7 +160,7 @@ export async function processAIColumns( return tracer.trace("processAIColumn", {}, async span => { span?.addTags({ table_id: table._id, column }) - const llmResponse = await llmWrapper.run(prompt!) + const llmResponse = await llmWrapper.run(prompt) return { ...row, [column]: llmResponse, diff --git a/packages/types/src/documents/app/row.ts b/packages/types/src/documents/app/row.ts index 6b6b38a5cf..bb58933b65 100644 --- a/packages/types/src/documents/app/row.ts +++ b/packages/types/src/documents/app/row.ts @@ -154,6 +154,7 @@ export const GroupByTypes = [ FieldType.BOOLEAN, FieldType.DATETIME, FieldType.BIGINT, + FieldType.AI, ] export function canGroupBy(type: FieldType) { diff --git a/packages/types/src/documents/app/table/schema.ts b/packages/types/src/documents/app/table/schema.ts index 771192e2f5..f4a6d8481d 100644 --- a/packages/types/src/documents/app/table/schema.ts +++ b/packages/types/src/documents/app/table/schema.ts @@ -123,7 +123,7 @@ export interface AIFieldMetadata extends BaseFieldSchema { operation: AIOperationEnum columns?: string[] column?: string - categories?: string[] + categories?: string prompt?: string language?: string }