Allow use of AI fields in view calculations.
This commit is contained in:
parent
aec4cc5bfb
commit
b9fb4416bb
|
@ -1 +1 @@
|
||||||
Subproject commit 32d84f109d4edc526145472a7446327312151442
|
Subproject commit 45f5b6fe9bbdbdf502581740ab43b82e8153260f
|
|
@ -162,6 +162,12 @@ export async function finaliseRow(
|
||||||
dynamic: false,
|
dynamic: false,
|
||||||
contextRows: [enrichedRow],
|
contextRows: [enrichedRow],
|
||||||
})
|
})
|
||||||
|
|
||||||
|
const flag1 = await features.isEnabled(FeatureFlag.BUDIBASE_AI)
|
||||||
|
const flag2 = await pro.features.isBudibaseAIEnabled()
|
||||||
|
const flag3 = await features.isEnabled(FeatureFlag.AI_CUSTOM_CONFIGS)
|
||||||
|
const flag4 = await pro.features.isAICustomConfigsEnabled()
|
||||||
|
|
||||||
const aiEnabled =
|
const aiEnabled =
|
||||||
((await features.isEnabled(FeatureFlag.BUDIBASE_AI)) &&
|
((await features.isEnabled(FeatureFlag.BUDIBASE_AI)) &&
|
||||||
(await pro.features.isBudibaseAIEnabled())) ||
|
(await pro.features.isBudibaseAIEnabled())) ||
|
||||||
|
|
|
@ -8,7 +8,13 @@ import {
|
||||||
import tk from "timekeeper"
|
import tk from "timekeeper"
|
||||||
import emitter from "../../../../src/events"
|
import emitter from "../../../../src/events"
|
||||||
import { outputProcessing } from "../../../utilities/rowProcessor"
|
import { outputProcessing } from "../../../utilities/rowProcessor"
|
||||||
import { context, InternalTable, tenancy, utils } from "@budibase/backend-core"
|
import {
|
||||||
|
context,
|
||||||
|
setEnv,
|
||||||
|
InternalTable,
|
||||||
|
tenancy,
|
||||||
|
utils,
|
||||||
|
} from "@budibase/backend-core"
|
||||||
import { quotas } from "@budibase/pro"
|
import { quotas } from "@budibase/pro"
|
||||||
import {
|
import {
|
||||||
AIOperationEnum,
|
AIOperationEnum,
|
||||||
|
@ -42,19 +48,8 @@ import { InternalTables } from "../../../db/utils"
|
||||||
import { withEnv } from "../../../environment"
|
import { withEnv } from "../../../environment"
|
||||||
import { JsTimeoutError } from "@budibase/string-templates"
|
import { JsTimeoutError } from "@budibase/string-templates"
|
||||||
import { isDate } from "../../../utilities"
|
import { isDate } from "../../../utilities"
|
||||||
|
import nock from "nock"
|
||||||
jest.mock("@budibase/pro", () => ({
|
import { mockChatGPTResponse } from "../../../tests/utilities/mocks/openai"
|
||||||
...jest.requireActual("@budibase/pro"),
|
|
||||||
ai: {
|
|
||||||
LargeLanguageModel: {
|
|
||||||
forCurrentTenant: async () => ({
|
|
||||||
llm: {},
|
|
||||||
run: jest.fn(() => `Mock LLM Response`),
|
|
||||||
buildPromptFromAIOperation: jest.fn(),
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}))
|
|
||||||
|
|
||||||
const timestamp = new Date("2023-01-26T11:48:57.597Z").toISOString()
|
const timestamp = new Date("2023-01-26T11:48:57.597Z").toISOString()
|
||||||
tk.freeze(timestamp)
|
tk.freeze(timestamp)
|
||||||
|
@ -99,6 +94,8 @@ if (descriptions.length) {
|
||||||
const ds = await dsProvider()
|
const ds = await dsProvider()
|
||||||
datasource = ds.datasource
|
datasource = ds.datasource
|
||||||
client = ds.client
|
client = ds.client
|
||||||
|
|
||||||
|
mocks.licenses.useCloudFree()
|
||||||
})
|
})
|
||||||
|
|
||||||
afterAll(async () => {
|
afterAll(async () => {
|
||||||
|
@ -172,10 +169,6 @@ if (descriptions.length) {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
beforeEach(async () => {
|
|
||||||
mocks.licenses.useCloudFree()
|
|
||||||
})
|
|
||||||
|
|
||||||
const getRowUsage = async () => {
|
const getRowUsage = async () => {
|
||||||
const { total } = await config.doInContext(undefined, () =>
|
const { total } = await config.doInContext(undefined, () =>
|
||||||
quotas.getCurrentUsageValues(
|
quotas.getCurrentUsageValues(
|
||||||
|
@ -3224,10 +3217,17 @@ if (descriptions.length) {
|
||||||
isInternal &&
|
isInternal &&
|
||||||
describe("AI fields", () => {
|
describe("AI fields", () => {
|
||||||
let table: Table
|
let table: Table
|
||||||
|
let envCleanup: () => void
|
||||||
|
|
||||||
beforeAll(async () => {
|
beforeAll(async () => {
|
||||||
mocks.licenses.useBudibaseAI()
|
mocks.licenses.useBudibaseAI()
|
||||||
mocks.licenses.useAICustomConfigs()
|
mocks.licenses.useAICustomConfigs()
|
||||||
|
envCleanup = setEnv({
|
||||||
|
OPENAI_API_KEY: "sk-abcdefghijklmnopqrstuvwxyz1234567890abcd",
|
||||||
|
})
|
||||||
|
|
||||||
|
mockChatGPTResponse("Mock LLM Response")
|
||||||
|
|
||||||
table = await config.api.table.save(
|
table = await config.api.table.save(
|
||||||
saveTableRequest({
|
saveTableRequest({
|
||||||
schema: {
|
schema: {
|
||||||
|
@ -3251,7 +3251,9 @@ if (descriptions.length) {
|
||||||
})
|
})
|
||||||
|
|
||||||
afterAll(() => {
|
afterAll(() => {
|
||||||
jest.unmock("@budibase/pro")
|
nock.cleanAll()
|
||||||
|
envCleanup()
|
||||||
|
mocks.licenses.useCloudFree()
|
||||||
})
|
})
|
||||||
|
|
||||||
it("should be able to save a row with an AI column", async () => {
|
it("should be able to save a row with an AI column", async () => {
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import {
|
import {
|
||||||
|
AIOperationEnum,
|
||||||
ArrayOperator,
|
ArrayOperator,
|
||||||
BasicOperator,
|
BasicOperator,
|
||||||
BBReferenceFieldSubType,
|
BBReferenceFieldSubType,
|
||||||
|
@ -42,7 +43,9 @@ import {
|
||||||
} from "../../../integrations/tests/utils"
|
} from "../../../integrations/tests/utils"
|
||||||
import merge from "lodash/merge"
|
import merge from "lodash/merge"
|
||||||
import { quotas } from "@budibase/pro"
|
import { quotas } from "@budibase/pro"
|
||||||
import { context, db, events, roles } from "@budibase/backend-core"
|
import { context, db, events, roles, setEnv } from "@budibase/backend-core"
|
||||||
|
import { mockChatGPTResponse } from "../../../tests/utilities/mocks/openai"
|
||||||
|
import nock from "nock"
|
||||||
|
|
||||||
const descriptions = datasourceDescribe({ exclude: [DatabaseName.MONGODB] })
|
const descriptions = datasourceDescribe({ exclude: [DatabaseName.MONGODB] })
|
||||||
|
|
||||||
|
@ -100,6 +103,7 @@ if (descriptions.length) {
|
||||||
|
|
||||||
beforeAll(async () => {
|
beforeAll(async () => {
|
||||||
await config.init()
|
await config.init()
|
||||||
|
mocks.licenses.useCloudFree()
|
||||||
|
|
||||||
const ds = await dsProvider()
|
const ds = await dsProvider()
|
||||||
rawDatasource = ds.rawDatasource
|
rawDatasource = ds.rawDatasource
|
||||||
|
@ -109,7 +113,6 @@ if (descriptions.length) {
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
jest.clearAllMocks()
|
jest.clearAllMocks()
|
||||||
mocks.licenses.useCloudFree()
|
|
||||||
})
|
})
|
||||||
|
|
||||||
describe("view crud", () => {
|
describe("view crud", () => {
|
||||||
|
@ -507,7 +510,6 @@ if (descriptions.length) {
|
||||||
})
|
})
|
||||||
|
|
||||||
it("readonly fields can be used on free license", async () => {
|
it("readonly fields can be used on free license", async () => {
|
||||||
mocks.licenses.useCloudFree()
|
|
||||||
const table = await config.api.table.save(
|
const table = await config.api.table.save(
|
||||||
saveTableRequest({
|
saveTableRequest({
|
||||||
schema: {
|
schema: {
|
||||||
|
@ -933,6 +935,94 @@ if (descriptions.length) {
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
describe("AI fields", () => {
|
||||||
|
let envCleanup: () => void
|
||||||
|
beforeAll(() => {
|
||||||
|
mocks.licenses.useBudibaseAI()
|
||||||
|
mocks.licenses.useAICustomConfigs()
|
||||||
|
envCleanup = setEnv({
|
||||||
|
OPENAI_API_KEY: "sk-abcdefghijklmnopqrstuvwxyz1234567890abcd",
|
||||||
|
})
|
||||||
|
|
||||||
|
mockChatGPTResponse(prompt => {
|
||||||
|
if (prompt.includes("elephant")) {
|
||||||
|
return "big"
|
||||||
|
}
|
||||||
|
if (prompt.includes("mouse")) {
|
||||||
|
return "small"
|
||||||
|
}
|
||||||
|
if (prompt.includes("whale")) {
|
||||||
|
return "big"
|
||||||
|
}
|
||||||
|
return "unknown"
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
afterAll(() => {
|
||||||
|
nock.cleanAll()
|
||||||
|
envCleanup()
|
||||||
|
mocks.licenses.useCloudFree()
|
||||||
|
})
|
||||||
|
|
||||||
|
it("can use AI fields in view calculations", async () => {
|
||||||
|
const table = await config.api.table.save(
|
||||||
|
saveTableRequest({
|
||||||
|
schema: {
|
||||||
|
animal: {
|
||||||
|
name: "animal",
|
||||||
|
type: FieldType.STRING,
|
||||||
|
},
|
||||||
|
bigOrSmall: {
|
||||||
|
name: "bigOrSmall",
|
||||||
|
type: FieldType.AI,
|
||||||
|
operation: AIOperationEnum.CATEGORISE_TEXT,
|
||||||
|
categories: "big,small",
|
||||||
|
columns: ["animal"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
|
const view = await config.api.viewV2.create({
|
||||||
|
tableId: table._id!,
|
||||||
|
name: generator.guid(),
|
||||||
|
type: ViewV2Type.CALCULATION,
|
||||||
|
schema: {
|
||||||
|
bigOrSmall: {
|
||||||
|
visible: true,
|
||||||
|
},
|
||||||
|
count: {
|
||||||
|
visible: true,
|
||||||
|
calculationType: CalculationType.COUNT,
|
||||||
|
field: "animal",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
await config.api.row.save(table._id!, {
|
||||||
|
animal: "elephant",
|
||||||
|
})
|
||||||
|
|
||||||
|
await config.api.row.save(table._id!, {
|
||||||
|
animal: "mouse",
|
||||||
|
})
|
||||||
|
|
||||||
|
await config.api.row.save(table._id!, {
|
||||||
|
animal: "whale",
|
||||||
|
})
|
||||||
|
|
||||||
|
const { rows } = await config.api.row.search(view.id, {
|
||||||
|
sort: "bigOrSmall",
|
||||||
|
sortOrder: SortOrder.ASCENDING,
|
||||||
|
})
|
||||||
|
expect(rows).toHaveLength(2)
|
||||||
|
expect(rows[0].bigOrSmall).toEqual("big")
|
||||||
|
expect(rows[1].bigOrSmall).toEqual("small")
|
||||||
|
expect(rows[0].count).toEqual(2)
|
||||||
|
expect(rows[1].count).toEqual(1)
|
||||||
|
})
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
describe("update", () => {
|
describe("update", () => {
|
||||||
|
@ -1836,7 +1926,6 @@ if (descriptions.length) {
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
mocks.licenses.useCloudFree()
|
|
||||||
const view = await getDelegate(res)
|
const view = await getDelegate(res)
|
||||||
expect(view.schema?.one).toEqual(
|
expect(view.schema?.one).toEqual(
|
||||||
expect.objectContaining({ visible: true, readonly: true })
|
expect.objectContaining({ visible: true, readonly: true })
|
||||||
|
|
|
@ -0,0 +1,46 @@
|
||||||
|
import nock from "nock"
|
||||||
|
|
||||||
|
let chatID = 1
|
||||||
|
|
||||||
|
export function mockChatGPTResponse(
|
||||||
|
response: string | ((prompt: string) => string)
|
||||||
|
) {
|
||||||
|
return nock("https://api.openai.com")
|
||||||
|
.post("/v1/chat/completions")
|
||||||
|
.reply(200, (uri, requestBody) => {
|
||||||
|
let content = response
|
||||||
|
if (typeof response === "function") {
|
||||||
|
const messages = (requestBody as any).messages
|
||||||
|
content = response(messages[0].content)
|
||||||
|
}
|
||||||
|
|
||||||
|
chatID++
|
||||||
|
|
||||||
|
return {
|
||||||
|
id: `chatcmpl-${chatID}`,
|
||||||
|
object: "chat.completion",
|
||||||
|
created: Math.floor(Date.now() / 1000),
|
||||||
|
model: "gpt-4o-mini",
|
||||||
|
system_fingerprint: `fp_${chatID}`,
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
index: 0,
|
||||||
|
message: { role: "assistant", content },
|
||||||
|
logprobs: null,
|
||||||
|
finish_reason: "stop",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
usage: {
|
||||||
|
prompt_tokens: 0,
|
||||||
|
completion_tokens: 0,
|
||||||
|
total_tokens: 0,
|
||||||
|
completion_tokens_details: {
|
||||||
|
reasoning_tokens: 0,
|
||||||
|
accepted_prediction_tokens: 0,
|
||||||
|
rejected_prediction_tokens: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.persist()
|
||||||
|
}
|
|
@ -160,7 +160,7 @@ export async function processAIColumns<T extends Row | Row[]>(
|
||||||
|
|
||||||
return tracer.trace("processAIColumn", {}, async span => {
|
return tracer.trace("processAIColumn", {}, async span => {
|
||||||
span?.addTags({ table_id: table._id, column })
|
span?.addTags({ table_id: table._id, column })
|
||||||
const llmResponse = await llmWrapper.run(prompt!)
|
const llmResponse = await llmWrapper.run(prompt)
|
||||||
return {
|
return {
|
||||||
...row,
|
...row,
|
||||||
[column]: llmResponse,
|
[column]: llmResponse,
|
||||||
|
|
|
@ -154,6 +154,7 @@ export const GroupByTypes = [
|
||||||
FieldType.BOOLEAN,
|
FieldType.BOOLEAN,
|
||||||
FieldType.DATETIME,
|
FieldType.DATETIME,
|
||||||
FieldType.BIGINT,
|
FieldType.BIGINT,
|
||||||
|
FieldType.AI,
|
||||||
]
|
]
|
||||||
|
|
||||||
export function canGroupBy(type: FieldType) {
|
export function canGroupBy(type: FieldType) {
|
||||||
|
|
|
@ -123,7 +123,7 @@ export interface AIFieldMetadata extends BaseFieldSchema {
|
||||||
operation: AIOperationEnum
|
operation: AIOperationEnum
|
||||||
columns?: string[]
|
columns?: string[]
|
||||||
column?: string
|
column?: string
|
||||||
categories?: string[]
|
categories?: string
|
||||||
prompt?: string
|
prompt?: string
|
||||||
language?: string
|
language?: string
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue