diff --git a/README.md b/README.md index 82ebeec..f441c39 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ Prisma generator for model factories. - [has-many / has-one relation](#has-many--has-one-relation) - [Custom scalar field generation](#custom-scalar-field-generation) - [Traits](#traits) + - [Callbacks](#callbacks) - [Field value precedence](#field-value-precedence) - [More examples](#more-examples) - [Generator configuration](#generator-configuration) @@ -375,6 +376,64 @@ Multiple traits are also available: await UserFactory.use("someTrait", "anotherTrait").create(); ``` +### Callbacks + +You can set callback function before or after factory execution. + +```ts +const UserFactory = defineUserFactory({ + onAfterCreate: async user => { + await PostFactory.craete({ + author: { connect: uesr }, + }); + }, +}); + +await UserFactory.create(); +``` + +Callback functions are also available within trait definition. + +```ts +const UserFactory = defineUserFactory({ + traits: { + withComment: { + onAfterCreate: async user => { + await PostFactory.craete({ + author: { connect: uesr }, + }); + }, + }, + }, +}); + +await UserFactory.create(); +await UserFactory.use("withComment").create(); +``` + +Note: The above code is to explain the callback. If you want to create association, first consider to use `defaultData` and `trait.data` option as in [has-many / has-one relation](#has-many--has-one-relation). + +The following three types are available as callback function: + +```ts +const UserFactory = defineUserFactory({ + onAfterBuild: async createInput => { + // do something + }, + onBeforeCreate: async createInput => { + // do something + }, + onAfterCreate: async createdData => { + // do something + }, +}); +``` + +And here, the parameter types are: + +- `createInput` is assignable to model create function parameter (e.g. `Prsima.UserCreateInput`). +- `createdData` is resolved object by model create function (e.g. `User` model type) + ### Field value precedence Each field is determined in the following priority order(lower numbers have higher priority): diff --git a/examples/example-prj/src/__generated__/fabbrica/index.d.ts b/examples/example-prj/src/__generated__/fabbrica/index.d.ts index 545d8d0..5fcc7dd 100644 --- a/examples/example-prj/src/__generated__/fabbrica/index.d.ts +++ b/examples/example-prj/src/__generated__/fabbrica/index.d.ts @@ -8,6 +8,11 @@ export { resetSequence, registerScalarFieldValueGenerator, resetScalarFieldValue type BuildDataOptions = { readonly seq: number; }; +type CallbackDefineOptions = { + onAfterBuild?: (createInput: TCreateInput) => void | PromiseLike; + onBeforeCreate?: (createInput: TCreateInput) => void | PromiseLike; + onAfterCreate?: (created: TCreated) => void | PromiseLike; +}; export declare const initialize: (options: import("@quramy/prisma-fabbrica/lib/initialize").InitializeOptions) => void; type UserFactoryDefineInput = { id?: string; @@ -18,14 +23,15 @@ type UserFactoryDefineInput = { posts?: Prisma.PostCreateNestedManyWithoutAuthorInput; comments?: Prisma.CommentCreateNestedManyWithoutAuthorInput; }; +type UserFactoryTrait = { + data?: Resolver, BuildDataOptions>; +} & CallbackDefineOptions; type UserFactoryDefineOptions = { defaultData?: Resolver; traits?: { - [traitName: string | symbol]: { - data: Resolver, BuildDataOptions>; - }; + [traitName: string | symbol]: UserFactoryTrait; }; -}; +} & CallbackDefineOptions; type UserTraitKeys = keyof TOptions["traits"]; export interface UserFactoryInterfaceWithoutTraits { readonly _factoryFor: "User"; @@ -60,14 +66,15 @@ type PostFactoryDefineInput = { comments?: Prisma.CommentCreateNestedManyWithoutPostInput; categories?: Prisma.CategoryCreateNestedManyWithoutPostsInput; }; +type PostFactoryTrait = { + data?: Resolver, BuildDataOptions>; +} & CallbackDefineOptions; type PostFactoryDefineOptions = { defaultData: Resolver; traits?: { - [traitName: string | symbol]: { - data: Resolver, BuildDataOptions>; - }; + [traitName: string | symbol]: PostFactoryTrait; }; -}; +} & CallbackDefineOptions; type PostTraitKeys = keyof TOptions["traits"]; export interface PostFactoryInterfaceWithoutTraits { readonly _factoryFor: "Post"; @@ -105,14 +112,15 @@ type CommentFactoryDefineInput = { post: CommentpostFactory | Prisma.PostCreateNestedOneWithoutCommentsInput; author: CommentauthorFactory | Prisma.UserCreateNestedOneWithoutCommentsInput; }; +type CommentFactoryTrait = { + data?: Resolver, BuildDataOptions>; +} & CallbackDefineOptions; type CommentFactoryDefineOptions = { defaultData: Resolver; traits?: { - [traitName: string | symbol]: { - data: Resolver, BuildDataOptions>; - }; + [traitName: string | symbol]: CommentFactoryTrait; }; -}; +} & CallbackDefineOptions; type CommentTraitKeys = keyof TOptions["traits"]; export interface CommentFactoryInterfaceWithoutTraits { readonly _factoryFor: "Comment"; @@ -139,14 +147,15 @@ type CategoryFactoryDefineInput = { name?: string; posts?: Prisma.PostCreateNestedManyWithoutCategoriesInput; }; +type CategoryFactoryTrait = { + data?: Resolver, BuildDataOptions>; +} & CallbackDefineOptions; type CategoryFactoryDefineOptions = { defaultData?: Resolver; traits?: { - [traitName: string | symbol]: { - data: Resolver, BuildDataOptions>; - }; + [traitName: string | symbol]: CategoryFactoryTrait; }; -}; +} & CallbackDefineOptions; type CategoryTraitKeys = keyof TOptions["traits"]; export interface CategoryFactoryInterfaceWithoutTraits { readonly _factoryFor: "Category"; diff --git a/examples/example-prj/src/__generated__/fabbrica/index.js b/examples/example-prj/src/__generated__/fabbrica/index.js index 4a84950..98a156c 100644 --- a/examples/example-prj/src/__generated__/fabbrica/index.js +++ b/examples/example-prj/src/__generated__/fabbrica/index.js @@ -61,11 +61,23 @@ function autoGenerateUserScalarsOrEnums({ seq }) { name: (0, internal_1.getScalarFieldValueGenerator)().String({ modelName: "User", fieldName: "name", isId: false, isUnique: false, seq }) }; } -function defineUserFactoryInternal({ defaultData: defaultDataResolver, traits: traitsDefs = {} }) { +function defineUserFactoryInternal({ defaultData: defaultDataResolver, onAfterBuild, onBeforeCreate, onAfterCreate, traits: traitsDefs = {} }) { const getFactoryWithTraits = (traitKeys = []) => { const seqKey = {}; const getSeq = () => (0, internal_1.getSequenceCounter)(seqKey); const screen = (0, internal_1.createScreener)("User", modelFieldDefinitions); + const handleAfterBuild = (0, internal_1.createCallbackChain)([ + onAfterBuild, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterBuild), + ]); + const handleBeforeCreate = (0, internal_1.createCallbackChain)([ + ...traitKeys.slice().reverse().map(traitKey => traitsDefs[traitKey].onBeforeCreate), + onBeforeCreate, + ]); + const handleAfterCreate = (0, internal_1.createCallbackChain)([ + onAfterCreate, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterCreate), + ]); const build = async (inputData = {}) => { const seq = getSeq(); const requiredScalarData = autoGenerateUserScalarsOrEnums({ seq }); @@ -81,6 +93,7 @@ function defineUserFactoryInternal({ defaultData: defaultDataResolver, traits: t }, resolveValue({ seq })); const defaultAssociations = {}; const data = { ...requiredScalarData, ...defaultData, ...defaultAssociations, ...inputData }; + await handleAfterBuild(data); return data; }; const buildList = (inputData) => Promise.all((0, internal_1.normalizeList)(inputData).map(data => build(data))); @@ -89,7 +102,10 @@ function defineUserFactoryInternal({ defaultData: defaultDataResolver, traits: t }); const create = async (inputData = {}) => { const data = await build(inputData).then(screen); - return await getClient().user.create({ data }); + await handleBeforeCreate(data); + const createdData = await getClient().user.create({ data }); + await handleAfterCreate(createdData); + return createdData; }; const createList = (inputData) => Promise.all((0, internal_1.normalizeList)(inputData).map(data => create(data))); const createForConnect = (inputData = {}) => create(inputData).then(pickForConnect); @@ -132,11 +148,23 @@ function autoGeneratePostScalarsOrEnums({ seq }) { title: (0, internal_1.getScalarFieldValueGenerator)().String({ modelName: "Post", fieldName: "title", isId: false, isUnique: false, seq }) }; } -function definePostFactoryInternal({ defaultData: defaultDataResolver, traits: traitsDefs = {} }) { +function definePostFactoryInternal({ defaultData: defaultDataResolver, onAfterBuild, onBeforeCreate, onAfterCreate, traits: traitsDefs = {} }) { const getFactoryWithTraits = (traitKeys = []) => { const seqKey = {}; const getSeq = () => (0, internal_1.getSequenceCounter)(seqKey); const screen = (0, internal_1.createScreener)("Post", modelFieldDefinitions); + const handleAfterBuild = (0, internal_1.createCallbackChain)([ + onAfterBuild, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterBuild), + ]); + const handleBeforeCreate = (0, internal_1.createCallbackChain)([ + ...traitKeys.slice().reverse().map(traitKey => traitsDefs[traitKey].onBeforeCreate), + onBeforeCreate, + ]); + const handleAfterCreate = (0, internal_1.createCallbackChain)([ + onAfterCreate, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterCreate), + ]); const build = async (inputData = {}) => { const seq = getSeq(); const requiredScalarData = autoGeneratePostScalarsOrEnums({ seq }); @@ -156,6 +184,7 @@ function definePostFactoryInternal({ defaultData: defaultDataResolver, traits: t } : defaultData.author }; const data = { ...requiredScalarData, ...defaultData, ...defaultAssociations, ...inputData }; + await handleAfterBuild(data); return data; }; const buildList = (inputData) => Promise.all((0, internal_1.normalizeList)(inputData).map(data => build(data))); @@ -164,7 +193,10 @@ function definePostFactoryInternal({ defaultData: defaultDataResolver, traits: t }); const create = async (inputData = {}) => { const data = await build(inputData).then(screen); - return await getClient().post.create({ data }); + await handleBeforeCreate(data); + const createdData = await getClient().post.create({ data }); + await handleAfterCreate(createdData); + return createdData; }; const createList = (inputData) => Promise.all((0, internal_1.normalizeList)(inputData).map(data => create(data))); const createForConnect = (inputData = {}) => create(inputData).then(pickForConnect); @@ -210,11 +242,23 @@ function autoGenerateCommentScalarsOrEnums({ seq }) { body: (0, internal_1.getScalarFieldValueGenerator)().String({ modelName: "Comment", fieldName: "body", isId: false, isUnique: false, seq }) }; } -function defineCommentFactoryInternal({ defaultData: defaultDataResolver, traits: traitsDefs = {} }) { +function defineCommentFactoryInternal({ defaultData: defaultDataResolver, onAfterBuild, onBeforeCreate, onAfterCreate, traits: traitsDefs = {} }) { const getFactoryWithTraits = (traitKeys = []) => { const seqKey = {}; const getSeq = () => (0, internal_1.getSequenceCounter)(seqKey); const screen = (0, internal_1.createScreener)("Comment", modelFieldDefinitions); + const handleAfterBuild = (0, internal_1.createCallbackChain)([ + onAfterBuild, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterBuild), + ]); + const handleBeforeCreate = (0, internal_1.createCallbackChain)([ + ...traitKeys.slice().reverse().map(traitKey => traitsDefs[traitKey].onBeforeCreate), + onBeforeCreate, + ]); + const handleAfterCreate = (0, internal_1.createCallbackChain)([ + onAfterCreate, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterCreate), + ]); const build = async (inputData = {}) => { const seq = getSeq(); const requiredScalarData = autoGenerateCommentScalarsOrEnums({ seq }); @@ -237,6 +281,7 @@ function defineCommentFactoryInternal({ defaultData: defaultDataResolver, traits } : defaultData.author }; const data = { ...requiredScalarData, ...defaultData, ...defaultAssociations, ...inputData }; + await handleAfterBuild(data); return data; }; const buildList = (inputData) => Promise.all((0, internal_1.normalizeList)(inputData).map(data => build(data))); @@ -245,7 +290,10 @@ function defineCommentFactoryInternal({ defaultData: defaultDataResolver, traits }); const create = async (inputData = {}) => { const data = await build(inputData).then(screen); - return await getClient().comment.create({ data }); + await handleBeforeCreate(data); + const createdData = await getClient().comment.create({ data }); + await handleAfterCreate(createdData); + return createdData; }; const createList = (inputData) => Promise.all((0, internal_1.normalizeList)(inputData).map(data => create(data))); const createForConnect = (inputData = {}) => create(inputData).then(pickForConnect); @@ -285,11 +333,23 @@ function autoGenerateCategoryScalarsOrEnums({ seq }) { name: (0, internal_1.getScalarFieldValueGenerator)().String({ modelName: "Category", fieldName: "name", isId: false, isUnique: true, seq }) }; } -function defineCategoryFactoryInternal({ defaultData: defaultDataResolver, traits: traitsDefs = {} }) { +function defineCategoryFactoryInternal({ defaultData: defaultDataResolver, onAfterBuild, onBeforeCreate, onAfterCreate, traits: traitsDefs = {} }) { const getFactoryWithTraits = (traitKeys = []) => { const seqKey = {}; const getSeq = () => (0, internal_1.getSequenceCounter)(seqKey); const screen = (0, internal_1.createScreener)("Category", modelFieldDefinitions); + const handleAfterBuild = (0, internal_1.createCallbackChain)([ + onAfterBuild, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterBuild), + ]); + const handleBeforeCreate = (0, internal_1.createCallbackChain)([ + ...traitKeys.slice().reverse().map(traitKey => traitsDefs[traitKey].onBeforeCreate), + onBeforeCreate, + ]); + const handleAfterCreate = (0, internal_1.createCallbackChain)([ + onAfterCreate, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterCreate), + ]); const build = async (inputData = {}) => { const seq = getSeq(); const requiredScalarData = autoGenerateCategoryScalarsOrEnums({ seq }); @@ -305,6 +365,7 @@ function defineCategoryFactoryInternal({ defaultData: defaultDataResolver, trait }, resolveValue({ seq })); const defaultAssociations = {}; const data = { ...requiredScalarData, ...defaultData, ...defaultAssociations, ...inputData }; + await handleAfterBuild(data); return data; }; const buildList = (inputData) => Promise.all((0, internal_1.normalizeList)(inputData).map(data => build(data))); @@ -313,7 +374,10 @@ function defineCategoryFactoryInternal({ defaultData: defaultDataResolver, trait }); const create = async (inputData = {}) => { const data = await build(inputData).then(screen); - return await getClient().category.create({ data }); + await handleBeforeCreate(data); + const createdData = await getClient().category.create({ data }); + await handleAfterCreate(createdData); + return createdData; }; const createList = (inputData) => Promise.all((0, internal_1.normalizeList)(inputData).map(data => create(data))); const createForConnect = (inputData = {}) => create(inputData).then(pickForConnect); diff --git a/examples/example-prj/src/trait.test.ts b/examples/example-prj/src/trait.test.ts index b0f8ff3..74335e0 100644 --- a/examples/example-prj/src/trait.test.ts +++ b/examples/example-prj/src/trait.test.ts @@ -2,12 +2,28 @@ import { defineUserFactory, definePostFactory, defineCommentFactory, - CommentFactoryInterface, + type PostFactoryInterface, + type CommentFactoryInterface, } from "./__generated__/fabbrica"; const prisma = jestPrisma.client; -export const UserFactory = defineUserFactory(); +export const UserFactory = defineUserFactory({ + traits: { + withSelfCommentedPost: { + onAfterCreate: async user => { + await getPostFactory().create({ + author: { + connect: user, + }, + comments: { + create: [await getCommentFactory().build({ author: { connect: user } })], + }, + }); + }, + }, + }, +}); export const PostFactory = definePostFactory({ defaultData: { @@ -43,6 +59,10 @@ export const CommentFactory = defineCommentFactory({ }, }); +function getPostFactory(): PostFactoryInterface { + return PostFactory; +} + function getCommentFactory(): CommentFactoryInterface { return CommentFactory; } @@ -77,5 +97,12 @@ describe("factories", () => { await expect(prisma.comment.count({ where: { postId: post2.id } })).resolves.toBe(1); }); }); + + describe("trait and callback", () => { + test("Execute other factory in callback", async () => { + const { id: userId } = await UserFactory.use("withSelfCommentedPost").create(); + await expect(prisma.comment.count({ where: { authorId: userId } })).resolves.toBe(1); + }); + }); }); }); diff --git a/packages/artifact-testing/fixtures/field-variation/__generated__/fabbrica/index.ts b/packages/artifact-testing/fixtures/field-variation/__generated__/fabbrica/index.ts index 74fa05d..79dec04 100644 --- a/packages/artifact-testing/fixtures/field-variation/__generated__/fabbrica/index.ts +++ b/packages/artifact-testing/fixtures/field-variation/__generated__/fabbrica/index.ts @@ -6,13 +6,19 @@ import type { Role } from "../client"; import type { Status } from "../client"; import { Prisma } from "../client"; import type { PrismaClient } from "../client"; -import { createInitializer, ModelWithFields, createScreener, getScalarFieldValueGenerator, Resolver, normalizeResolver, normalizeList, getSequenceCounter, } from "@quramy/prisma-fabbrica/lib/internal"; +import { createInitializer, ModelWithFields, createScreener, getScalarFieldValueGenerator, Resolver, normalizeResolver, normalizeList, getSequenceCounter, createCallbackChain, } from "@quramy/prisma-fabbrica/lib/internal"; export { resetSequence, registerScalarFieldValueGenerator, resetScalarFieldValueGenerator } from "@quramy/prisma-fabbrica/lib/internal"; type BuildDataOptions = { readonly seq: number; }; +type CallbackDefineOptions = { + onAfterBuild?: (createInput: TCreateInput) => void | PromiseLike; + onBeforeCreate?: (createInput: TCreateInput) => void | PromiseLike; + onAfterCreate?: (created: TCreated) => void | PromiseLike; +}; + const initializer = createInitializer(); const { getClient } = initializer; @@ -49,14 +55,16 @@ type UserFactoryDefineInput = { status?: Status | null; }; +type UserFactoryTrait = { + data?: Resolver, BuildDataOptions>; +} & CallbackDefineOptions; + type UserFactoryDefineOptions = { defaultData?: Resolver; traits?: { - [traitName: string | symbol]: { - data: Resolver, BuildDataOptions>; - }; + [traitName: string | symbol]: UserFactoryTrait; }; -}; +} & CallbackDefineOptions; type UserTraitKeys = keyof TOptions["traits"]; @@ -84,11 +92,23 @@ function autoGenerateUserScalarsOrEnums({ seq }: { }; } -function defineUserFactoryInternal({ defaultData: defaultDataResolver, traits: traitsDefs = {} }: TOptions): UserFactoryInterface { +function defineUserFactoryInternal({ defaultData: defaultDataResolver, onAfterBuild, onBeforeCreate, onAfterCreate, traits: traitsDefs = {} }: TOptions): UserFactoryInterface { const getFactoryWithTraits = (traitKeys: readonly UserTraitKeys[] = []) => { const seqKey = {}; const getSeq = () => getSequenceCounter(seqKey); const screen = createScreener("User", modelFieldDefinitions); + const handleAfterBuild = createCallbackChain([ + onAfterBuild, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterBuild), + ]); + const handleBeforeCreate = createCallbackChain([ + ...traitKeys.slice().reverse().map(traitKey => traitsDefs[traitKey].onBeforeCreate), + onBeforeCreate, + ]); + const handleAfterCreate = createCallbackChain([ + onAfterCreate, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterCreate), + ]); const build = async (inputData: Partial = {}) => { const seq = getSeq(); const requiredScalarData = autoGenerateUserScalarsOrEnums({ seq }); @@ -104,6 +124,7 @@ function defineUserFactoryInternal({ }, resolveValue({ seq })); const defaultAssociations = {}; const data: Prisma.UserCreateInput = { ...requiredScalarData, ...defaultData, ...defaultAssociations, ...inputData }; + await handleAfterBuild(data); return data; }; const buildList = (inputData: number | readonly Partial[]) => Promise.all(normalizeList(inputData).map(data => build(data))); @@ -112,7 +133,10 @@ function defineUserFactoryInternal({ }); const create = async (inputData: Partial = {}) => { const data = await build(inputData).then(screen); - return await getClient().user.create({ data }); + await handleBeforeCreate(data); + const createdData = await getClient().user.create({ data }); + await handleAfterCreate(createdData); + return createdData; }; const createList = (inputData: number | readonly Partial[]) => Promise.all(normalizeList(inputData).map(data => create(data))); const createForConnect = (inputData: Partial = {}) => create(inputData).then(pickForConnect); @@ -157,14 +181,16 @@ type ComplexIdModelFactoryDefineInput = { lastName?: string; }; +type ComplexIdModelFactoryTrait = { + data?: Resolver, BuildDataOptions>; +} & CallbackDefineOptions; + type ComplexIdModelFactoryDefineOptions = { defaultData?: Resolver; traits?: { - [traitName: string | symbol]: { - data: Resolver, BuildDataOptions>; - }; + [traitName: string | symbol]: ComplexIdModelFactoryTrait; }; -}; +} & CallbackDefineOptions; type ComplexIdModelTraitKeys = keyof TOptions["traits"]; @@ -192,11 +218,23 @@ function autoGenerateComplexIdModelScalarsOrEnums({ seq }: { }; } -function defineComplexIdModelFactoryInternal({ defaultData: defaultDataResolver, traits: traitsDefs = {} }: TOptions): ComplexIdModelFactoryInterface { +function defineComplexIdModelFactoryInternal({ defaultData: defaultDataResolver, onAfterBuild, onBeforeCreate, onAfterCreate, traits: traitsDefs = {} }: TOptions): ComplexIdModelFactoryInterface { const getFactoryWithTraits = (traitKeys: readonly ComplexIdModelTraitKeys[] = []) => { const seqKey = {}; const getSeq = () => getSequenceCounter(seqKey); const screen = createScreener("ComplexIdModel", modelFieldDefinitions); + const handleAfterBuild = createCallbackChain([ + onAfterBuild, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterBuild), + ]); + const handleBeforeCreate = createCallbackChain([ + ...traitKeys.slice().reverse().map(traitKey => traitsDefs[traitKey].onBeforeCreate), + onBeforeCreate, + ]); + const handleAfterCreate = createCallbackChain([ + onAfterCreate, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterCreate), + ]); const build = async (inputData: Partial = {}) => { const seq = getSeq(); const requiredScalarData = autoGenerateComplexIdModelScalarsOrEnums({ seq }); @@ -212,6 +250,7 @@ function defineComplexIdModelFactoryInternal[]) => Promise.all(normalizeList(inputData).map(data => build(data))); @@ -221,7 +260,10 @@ function defineComplexIdModelFactoryInternal = {}) => { const data = await build(inputData).then(screen); - return await getClient().complexIdModel.create({ data }); + await handleBeforeCreate(data); + const createdData = await getClient().complexIdModel.create({ data }); + await handleAfterCreate(createdData); + return createdData; }; const createList = (inputData: number | readonly Partial[]) => Promise.all(normalizeList(inputData).map(data => create(data))); const createForConnect = (inputData: Partial = {}) => create(inputData).then(pickForConnect); @@ -290,14 +332,16 @@ type FieldTypePatternModelFactoryDefineInput = { nullableBigInt?: (bigint | number) | null; }; +type FieldTypePatternModelFactoryTrait = { + data?: Resolver, BuildDataOptions>; +} & CallbackDefineOptions; + type FieldTypePatternModelFactoryDefineOptions = { defaultData?: Resolver; traits?: { - [traitName: string | symbol]: { - data: Resolver, BuildDataOptions>; - }; + [traitName: string | symbol]: FieldTypePatternModelFactoryTrait; }; -}; +} & CallbackDefineOptions; type FieldTypePatternModelTraitKeys = keyof TOptions["traits"]; @@ -332,11 +376,23 @@ function autoGenerateFieldTypePatternModelScalarsOrEnums({ seq }: { }; } -function defineFieldTypePatternModelFactoryInternal({ defaultData: defaultDataResolver, traits: traitsDefs = {} }: TOptions): FieldTypePatternModelFactoryInterface { +function defineFieldTypePatternModelFactoryInternal({ defaultData: defaultDataResolver, onAfterBuild, onBeforeCreate, onAfterCreate, traits: traitsDefs = {} }: TOptions): FieldTypePatternModelFactoryInterface { const getFactoryWithTraits = (traitKeys: readonly FieldTypePatternModelTraitKeys[] = []) => { const seqKey = {}; const getSeq = () => getSequenceCounter(seqKey); const screen = createScreener("FieldTypePatternModel", modelFieldDefinitions); + const handleAfterBuild = createCallbackChain([ + onAfterBuild, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterBuild), + ]); + const handleBeforeCreate = createCallbackChain([ + ...traitKeys.slice().reverse().map(traitKey => traitsDefs[traitKey].onBeforeCreate), + onBeforeCreate, + ]); + const handleAfterCreate = createCallbackChain([ + onAfterCreate, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterCreate), + ]); const build = async (inputData: Partial = {}) => { const seq = getSeq(); const requiredScalarData = autoGenerateFieldTypePatternModelScalarsOrEnums({ seq }); @@ -352,6 +408,7 @@ function defineFieldTypePatternModelFactoryInternal[]) => Promise.all(normalizeList(inputData).map(data => build(data))); @@ -360,7 +417,10 @@ function defineFieldTypePatternModelFactoryInternal = {}) => { const data = await build(inputData).then(screen); - return await getClient().fieldTypePatternModel.create({ data }); + await handleBeforeCreate(data); + const createdData = await getClient().fieldTypePatternModel.create({ data }); + await handleAfterCreate(createdData); + return createdData; }; const createList = (inputData: number | readonly Partial[]) => Promise.all(normalizeList(inputData).map(data => create(data))); const createForConnect = (inputData: Partial = {}) => create(inputData).then(pickForConnect); @@ -403,14 +463,16 @@ type NoPkModelFactoryDefineInput = { id?: number; }; +type NoPkModelFactoryTrait = { + data?: Resolver, BuildDataOptions>; +} & CallbackDefineOptions; + type NoPkModelFactoryDefineOptions = { defaultData?: Resolver; traits?: { - [traitName: string | symbol]: { - data: Resolver, BuildDataOptions>; - }; + [traitName: string | symbol]: NoPkModelFactoryTrait; }; -}; +} & CallbackDefineOptions; type NoPkModelTraitKeys = keyof TOptions["traits"]; @@ -437,11 +499,23 @@ function autoGenerateNoPkModelScalarsOrEnums({ seq }: { }; } -function defineNoPkModelFactoryInternal({ defaultData: defaultDataResolver, traits: traitsDefs = {} }: TOptions): NoPkModelFactoryInterface { +function defineNoPkModelFactoryInternal({ defaultData: defaultDataResolver, onAfterBuild, onBeforeCreate, onAfterCreate, traits: traitsDefs = {} }: TOptions): NoPkModelFactoryInterface { const getFactoryWithTraits = (traitKeys: readonly NoPkModelTraitKeys[] = []) => { const seqKey = {}; const getSeq = () => getSequenceCounter(seqKey); const screen = createScreener("NoPkModel", modelFieldDefinitions); + const handleAfterBuild = createCallbackChain([ + onAfterBuild, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterBuild), + ]); + const handleBeforeCreate = createCallbackChain([ + ...traitKeys.slice().reverse().map(traitKey => traitsDefs[traitKey].onBeforeCreate), + onBeforeCreate, + ]); + const handleAfterCreate = createCallbackChain([ + onAfterCreate, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterCreate), + ]); const build = async (inputData: Partial = {}) => { const seq = getSeq(); const requiredScalarData = autoGenerateNoPkModelScalarsOrEnums({ seq }); @@ -457,6 +531,7 @@ function defineNoPkModelFactoryInternal[]) => Promise.all(normalizeList(inputData).map(data => build(data))); @@ -465,7 +540,10 @@ function defineNoPkModelFactoryInternal = {}) => { const data = await build(inputData).then(screen); - return await getClient().noPkModel.create({ data }); + await handleBeforeCreate(data); + const createdData = await getClient().noPkModel.create({ data }); + await handleAfterCreate(createdData); + return createdData; }; const createList = (inputData: number | readonly Partial[]) => Promise.all(normalizeList(inputData).map(data => create(data))); const createForConnect = (inputData: Partial = {}) => create(inputData).then(pickForConnect); diff --git a/packages/artifact-testing/fixtures/relations-many-to-many/__generated__/fabbrica/index.ts b/packages/artifact-testing/fixtures/relations-many-to-many/__generated__/fabbrica/index.ts index 2042882..498bcfa 100644 --- a/packages/artifact-testing/fixtures/relations-many-to-many/__generated__/fabbrica/index.ts +++ b/packages/artifact-testing/fixtures/relations-many-to-many/__generated__/fabbrica/index.ts @@ -2,13 +2,19 @@ import type { Post } from "../client"; import type { Category } from "../client"; import { Prisma } from "../client"; import type { PrismaClient } from "../client"; -import { createInitializer, ModelWithFields, createScreener, getScalarFieldValueGenerator, Resolver, normalizeResolver, normalizeList, getSequenceCounter, } from "@quramy/prisma-fabbrica/lib/internal"; +import { createInitializer, ModelWithFields, createScreener, getScalarFieldValueGenerator, Resolver, normalizeResolver, normalizeList, getSequenceCounter, createCallbackChain, } from "@quramy/prisma-fabbrica/lib/internal"; export { resetSequence, registerScalarFieldValueGenerator, resetScalarFieldValueGenerator } from "@quramy/prisma-fabbrica/lib/internal"; type BuildDataOptions = { readonly seq: number; }; +type CallbackDefineOptions = { + onAfterBuild?: (createInput: TCreateInput) => void | PromiseLike; + onBeforeCreate?: (createInput: TCreateInput) => void | PromiseLike; + onAfterCreate?: (created: TCreated) => void | PromiseLike; +}; + const initializer = createInitializer(); const { getClient } = initializer; @@ -42,14 +48,16 @@ type PostFactoryDefineInput = { categories?: Prisma.CategoryCreateNestedManyWithoutPostsInput; }; +type PostFactoryTrait = { + data?: Resolver, BuildDataOptions>; +} & CallbackDefineOptions; + type PostFactoryDefineOptions = { defaultData?: Resolver; traits?: { - [traitName: string | symbol]: { - data: Resolver, BuildDataOptions>; - }; + [traitName: string | symbol]: PostFactoryTrait; }; -}; +} & CallbackDefineOptions; type PostTraitKeys = keyof TOptions["traits"]; @@ -77,11 +85,23 @@ function autoGeneratePostScalarsOrEnums({ seq }: { }; } -function definePostFactoryInternal({ defaultData: defaultDataResolver, traits: traitsDefs = {} }: TOptions): PostFactoryInterface { +function definePostFactoryInternal({ defaultData: defaultDataResolver, onAfterBuild, onBeforeCreate, onAfterCreate, traits: traitsDefs = {} }: TOptions): PostFactoryInterface { const getFactoryWithTraits = (traitKeys: readonly PostTraitKeys[] = []) => { const seqKey = {}; const getSeq = () => getSequenceCounter(seqKey); const screen = createScreener("Post", modelFieldDefinitions); + const handleAfterBuild = createCallbackChain([ + onAfterBuild, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterBuild), + ]); + const handleBeforeCreate = createCallbackChain([ + ...traitKeys.slice().reverse().map(traitKey => traitsDefs[traitKey].onBeforeCreate), + onBeforeCreate, + ]); + const handleAfterCreate = createCallbackChain([ + onAfterCreate, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterCreate), + ]); const build = async (inputData: Partial = {}) => { const seq = getSeq(); const requiredScalarData = autoGeneratePostScalarsOrEnums({ seq }); @@ -97,6 +117,7 @@ function definePostFactoryInternal({ }, resolveValue({ seq })); const defaultAssociations = {}; const data: Prisma.PostCreateInput = { ...requiredScalarData, ...defaultData, ...defaultAssociations, ...inputData }; + await handleAfterBuild(data); return data; }; const buildList = (inputData: number | readonly Partial[]) => Promise.all(normalizeList(inputData).map(data => build(data))); @@ -105,7 +126,10 @@ function definePostFactoryInternal({ }); const create = async (inputData: Partial = {}) => { const data = await build(inputData).then(screen); - return await getClient().post.create({ data }); + await handleBeforeCreate(data); + const createdData = await getClient().post.create({ data }); + await handleAfterCreate(createdData); + return createdData; }; const createList = (inputData: number | readonly Partial[]) => Promise.all(normalizeList(inputData).map(data => create(data))); const createForConnect = (inputData: Partial = {}) => create(inputData).then(pickForConnect); @@ -151,14 +175,16 @@ type CategoryFactoryDefineInput = { posts?: Prisma.PostCreateNestedManyWithoutCategoriesInput; }; +type CategoryFactoryTrait = { + data?: Resolver, BuildDataOptions>; +} & CallbackDefineOptions; + type CategoryFactoryDefineOptions = { defaultData?: Resolver; traits?: { - [traitName: string | symbol]: { - data: Resolver, BuildDataOptions>; - }; + [traitName: string | symbol]: CategoryFactoryTrait; }; -}; +} & CallbackDefineOptions; type CategoryTraitKeys = keyof TOptions["traits"]; @@ -186,11 +212,23 @@ function autoGenerateCategoryScalarsOrEnums({ seq }: { }; } -function defineCategoryFactoryInternal({ defaultData: defaultDataResolver, traits: traitsDefs = {} }: TOptions): CategoryFactoryInterface { +function defineCategoryFactoryInternal({ defaultData: defaultDataResolver, onAfterBuild, onBeforeCreate, onAfterCreate, traits: traitsDefs = {} }: TOptions): CategoryFactoryInterface { const getFactoryWithTraits = (traitKeys: readonly CategoryTraitKeys[] = []) => { const seqKey = {}; const getSeq = () => getSequenceCounter(seqKey); const screen = createScreener("Category", modelFieldDefinitions); + const handleAfterBuild = createCallbackChain([ + onAfterBuild, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterBuild), + ]); + const handleBeforeCreate = createCallbackChain([ + ...traitKeys.slice().reverse().map(traitKey => traitsDefs[traitKey].onBeforeCreate), + onBeforeCreate, + ]); + const handleAfterCreate = createCallbackChain([ + onAfterCreate, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterCreate), + ]); const build = async (inputData: Partial = {}) => { const seq = getSeq(); const requiredScalarData = autoGenerateCategoryScalarsOrEnums({ seq }); @@ -206,6 +244,7 @@ function defineCategoryFactoryInternal[]) => Promise.all(normalizeList(inputData).map(data => build(data))); @@ -214,7 +253,10 @@ function defineCategoryFactoryInternal = {}) => { const data = await build(inputData).then(screen); - return await getClient().category.create({ data }); + await handleBeforeCreate(data); + const createdData = await getClient().category.create({ data }); + await handleAfterCreate(createdData); + return createdData; }; const createList = (inputData: number | readonly Partial[]) => Promise.all(normalizeList(inputData).map(data => create(data))); const createForConnect = (inputData: Partial = {}) => create(inputData).then(pickForConnect); diff --git a/packages/artifact-testing/fixtures/relations-one-to-many/__generated__/fabbrica/index.ts b/packages/artifact-testing/fixtures/relations-one-to-many/__generated__/fabbrica/index.ts index ae26243..05f690a 100644 --- a/packages/artifact-testing/fixtures/relations-one-to-many/__generated__/fabbrica/index.ts +++ b/packages/artifact-testing/fixtures/relations-one-to-many/__generated__/fabbrica/index.ts @@ -3,13 +3,19 @@ import type { Post } from "../client"; import type { Review } from "../client"; import { Prisma } from "../client"; import type { PrismaClient } from "../client"; -import { createInitializer, ModelWithFields, createScreener, getScalarFieldValueGenerator, Resolver, normalizeResolver, normalizeList, getSequenceCounter, } from "@quramy/prisma-fabbrica/lib/internal"; +import { createInitializer, ModelWithFields, createScreener, getScalarFieldValueGenerator, Resolver, normalizeResolver, normalizeList, getSequenceCounter, createCallbackChain, } from "@quramy/prisma-fabbrica/lib/internal"; export { resetSequence, registerScalarFieldValueGenerator, resetScalarFieldValueGenerator } from "@quramy/prisma-fabbrica/lib/internal"; type BuildDataOptions = { readonly seq: number; }; +type CallbackDefineOptions = { + onAfterBuild?: (createInput: TCreateInput) => void | PromiseLike; + onBeforeCreate?: (createInput: TCreateInput) => void | PromiseLike; + onAfterCreate?: (created: TCreated) => void | PromiseLike; +}; + const initializer = createInitializer(); const { getClient } = initializer; @@ -63,14 +69,16 @@ type UserFactoryDefineInput = { reviews?: Prisma.ReviewCreateNestedManyWithoutReviewerInput; }; +type UserFactoryTrait = { + data?: Resolver, BuildDataOptions>; +} & CallbackDefineOptions; + type UserFactoryDefineOptions = { defaultData?: Resolver; traits?: { - [traitName: string | symbol]: { - data: Resolver, BuildDataOptions>; - }; + [traitName: string | symbol]: UserFactoryTrait; }; -}; +} & CallbackDefineOptions; type UserTraitKeys = keyof TOptions["traits"]; @@ -98,11 +106,23 @@ function autoGenerateUserScalarsOrEnums({ seq }: { }; } -function defineUserFactoryInternal({ defaultData: defaultDataResolver, traits: traitsDefs = {} }: TOptions): UserFactoryInterface { +function defineUserFactoryInternal({ defaultData: defaultDataResolver, onAfterBuild, onBeforeCreate, onAfterCreate, traits: traitsDefs = {} }: TOptions): UserFactoryInterface { const getFactoryWithTraits = (traitKeys: readonly UserTraitKeys[] = []) => { const seqKey = {}; const getSeq = () => getSequenceCounter(seqKey); const screen = createScreener("User", modelFieldDefinitions); + const handleAfterBuild = createCallbackChain([ + onAfterBuild, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterBuild), + ]); + const handleBeforeCreate = createCallbackChain([ + ...traitKeys.slice().reverse().map(traitKey => traitsDefs[traitKey].onBeforeCreate), + onBeforeCreate, + ]); + const handleAfterCreate = createCallbackChain([ + onAfterCreate, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterCreate), + ]); const build = async (inputData: Partial = {}) => { const seq = getSeq(); const requiredScalarData = autoGenerateUserScalarsOrEnums({ seq }); @@ -118,6 +138,7 @@ function defineUserFactoryInternal({ }, resolveValue({ seq })); const defaultAssociations = {}; const data: Prisma.UserCreateInput = { ...requiredScalarData, ...defaultData, ...defaultAssociations, ...inputData }; + await handleAfterBuild(data); return data; }; const buildList = (inputData: number | readonly Partial[]) => Promise.all(normalizeList(inputData).map(data => build(data))); @@ -126,7 +147,10 @@ function defineUserFactoryInternal({ }); const create = async (inputData: Partial = {}) => { const data = await build(inputData).then(screen); - return await getClient().user.create({ data }); + await handleBeforeCreate(data); + const createdData = await getClient().user.create({ data }); + await handleAfterCreate(createdData); + return createdData; }; const createList = (inputData: number | readonly Partial[]) => Promise.all(normalizeList(inputData).map(data => create(data))); const createForConnect = (inputData: Partial = {}) => create(inputData).then(pickForConnect); @@ -178,14 +202,16 @@ type PostFactoryDefineInput = { reviews?: Prisma.ReviewCreateNestedManyWithoutPostInput; }; +type PostFactoryTrait = { + data?: Resolver, BuildDataOptions>; +} & CallbackDefineOptions; + type PostFactoryDefineOptions = { defaultData?: Resolver; traits?: { - [traitName: string | symbol]: { - data: Resolver, BuildDataOptions>; - }; + [traitName: string | symbol]: PostFactoryTrait; }; -}; +} & CallbackDefineOptions; function isPostauthorFactory(x: PostauthorFactory | Prisma.UserCreateNestedOneWithoutPostsInput | undefined): x is PostauthorFactory { return (x as any)?._factoryFor === "User"; @@ -217,11 +243,23 @@ function autoGeneratePostScalarsOrEnums({ seq }: { }; } -function definePostFactoryInternal({ defaultData: defaultDataResolver, traits: traitsDefs = {} }: TOptions): PostFactoryInterface { +function definePostFactoryInternal({ defaultData: defaultDataResolver, onAfterBuild, onBeforeCreate, onAfterCreate, traits: traitsDefs = {} }: TOptions): PostFactoryInterface { const getFactoryWithTraits = (traitKeys: readonly PostTraitKeys[] = []) => { const seqKey = {}; const getSeq = () => getSequenceCounter(seqKey); const screen = createScreener("Post", modelFieldDefinitions); + const handleAfterBuild = createCallbackChain([ + onAfterBuild, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterBuild), + ]); + const handleBeforeCreate = createCallbackChain([ + ...traitKeys.slice().reverse().map(traitKey => traitsDefs[traitKey].onBeforeCreate), + onBeforeCreate, + ]); + const handleAfterCreate = createCallbackChain([ + onAfterCreate, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterCreate), + ]); const build = async (inputData: Partial = {}) => { const seq = getSeq(); const requiredScalarData = autoGeneratePostScalarsOrEnums({ seq }); @@ -241,6 +279,7 @@ function definePostFactoryInternal({ } : defaultData.author }; const data: Prisma.PostCreateInput = { ...requiredScalarData, ...defaultData, ...defaultAssociations, ...inputData }; + await handleAfterBuild(data); return data; }; const buildList = (inputData: number | readonly Partial[]) => Promise.all(normalizeList(inputData).map(data => build(data))); @@ -249,7 +288,10 @@ function definePostFactoryInternal({ }); const create = async (inputData: Partial = {}) => { const data = await build(inputData).then(screen); - return await getClient().post.create({ data }); + await handleBeforeCreate(data); + const createdData = await getClient().post.create({ data }); + await handleAfterCreate(createdData); + return createdData; }; const createList = (inputData: number | readonly Partial[]) => Promise.all(normalizeList(inputData).map(data => create(data))); const createForConnect = (inputData: Partial = {}) => create(inputData).then(pickForConnect); @@ -306,14 +348,16 @@ type ReviewFactoryDefineInput = { reviewer: ReviewreviewerFactory | Prisma.UserCreateNestedOneWithoutReviewsInput; }; +type ReviewFactoryTrait = { + data?: Resolver, BuildDataOptions>; +} & CallbackDefineOptions; + type ReviewFactoryDefineOptions = { defaultData: Resolver; traits?: { - [traitName: string | symbol]: { - data: Resolver, BuildDataOptions>; - }; + [traitName: string | symbol]: ReviewFactoryTrait; }; -}; +} & CallbackDefineOptions; function isReviewpostFactory(x: ReviewpostFactory | Prisma.PostCreateNestedOneWithoutReviewsInput | undefined): x is ReviewpostFactory { return (x as any)?._factoryFor === "Post"; @@ -349,11 +393,23 @@ function autoGenerateReviewScalarsOrEnums({ seq }: { }; } -function defineReviewFactoryInternal({ defaultData: defaultDataResolver, traits: traitsDefs = {} }: TOptions): ReviewFactoryInterface { +function defineReviewFactoryInternal({ defaultData: defaultDataResolver, onAfterBuild, onBeforeCreate, onAfterCreate, traits: traitsDefs = {} }: TOptions): ReviewFactoryInterface { const getFactoryWithTraits = (traitKeys: readonly ReviewTraitKeys[] = []) => { const seqKey = {}; const getSeq = () => getSequenceCounter(seqKey); const screen = createScreener("Review", modelFieldDefinitions); + const handleAfterBuild = createCallbackChain([ + onAfterBuild, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterBuild), + ]); + const handleBeforeCreate = createCallbackChain([ + ...traitKeys.slice().reverse().map(traitKey => traitsDefs[traitKey].onBeforeCreate), + onBeforeCreate, + ]); + const handleAfterCreate = createCallbackChain([ + onAfterCreate, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterCreate), + ]); const build = async (inputData: Partial = {}) => { const seq = getSeq(); const requiredScalarData = autoGenerateReviewScalarsOrEnums({ seq }); @@ -376,6 +432,7 @@ function defineReviewFactoryInternal[]) => Promise.all(normalizeList(inputData).map(data => build(data))); @@ -384,7 +441,10 @@ function defineReviewFactoryInternal = {}) => { const data = await build(inputData).then(screen); - return await getClient().review.create({ data }); + await handleBeforeCreate(data); + const createdData = await getClient().review.create({ data }); + await handleAfterCreate(createdData); + return createdData; }; const createList = (inputData: number | readonly Partial[]) => Promise.all(normalizeList(inputData).map(data => create(data))); const createForConnect = (inputData: Partial = {}) => create(inputData).then(pickForConnect); diff --git a/packages/artifact-testing/fixtures/relations-one-to-one/__generated__/fabbrica/index.ts b/packages/artifact-testing/fixtures/relations-one-to-one/__generated__/fabbrica/index.ts index 044ca80..e238537 100644 --- a/packages/artifact-testing/fixtures/relations-one-to-one/__generated__/fabbrica/index.ts +++ b/packages/artifact-testing/fixtures/relations-one-to-one/__generated__/fabbrica/index.ts @@ -2,13 +2,19 @@ import type { User } from "../client"; import type { Profile } from "../client"; import { Prisma } from "../client"; import type { PrismaClient } from "../client"; -import { createInitializer, ModelWithFields, createScreener, getScalarFieldValueGenerator, Resolver, normalizeResolver, normalizeList, getSequenceCounter, } from "@quramy/prisma-fabbrica/lib/internal"; +import { createInitializer, ModelWithFields, createScreener, getScalarFieldValueGenerator, Resolver, normalizeResolver, normalizeList, getSequenceCounter, createCallbackChain, } from "@quramy/prisma-fabbrica/lib/internal"; export { resetSequence, registerScalarFieldValueGenerator, resetScalarFieldValueGenerator } from "@quramy/prisma-fabbrica/lib/internal"; type BuildDataOptions = { readonly seq: number; }; +type CallbackDefineOptions = { + onAfterBuild?: (createInput: TCreateInput) => void | PromiseLike; + onBeforeCreate?: (createInput: TCreateInput) => void | PromiseLike; + onAfterCreate?: (created: TCreated) => void | PromiseLike; +}; + const initializer = createInitializer(); const { getClient } = initializer; @@ -47,14 +53,16 @@ type UserFactoryDefineInput = { profile?: UserprofileFactory | Prisma.ProfileCreateNestedOneWithoutUserInput; }; +type UserFactoryTrait = { + data?: Resolver, BuildDataOptions>; +} & CallbackDefineOptions; + type UserFactoryDefineOptions = { defaultData?: Resolver; traits?: { - [traitName: string | symbol]: { - data: Resolver, BuildDataOptions>; - }; + [traitName: string | symbol]: UserFactoryTrait; }; -}; +} & CallbackDefineOptions; function isUserprofileFactory(x: UserprofileFactory | Prisma.ProfileCreateNestedOneWithoutUserInput | undefined): x is UserprofileFactory { return (x as any)?._factoryFor === "Profile"; @@ -86,11 +94,23 @@ function autoGenerateUserScalarsOrEnums({ seq }: { }; } -function defineUserFactoryInternal({ defaultData: defaultDataResolver, traits: traitsDefs = {} }: TOptions): UserFactoryInterface { +function defineUserFactoryInternal({ defaultData: defaultDataResolver, onAfterBuild, onBeforeCreate, onAfterCreate, traits: traitsDefs = {} }: TOptions): UserFactoryInterface { const getFactoryWithTraits = (traitKeys: readonly UserTraitKeys[] = []) => { const seqKey = {}; const getSeq = () => getSequenceCounter(seqKey); const screen = createScreener("User", modelFieldDefinitions); + const handleAfterBuild = createCallbackChain([ + onAfterBuild, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterBuild), + ]); + const handleBeforeCreate = createCallbackChain([ + ...traitKeys.slice().reverse().map(traitKey => traitsDefs[traitKey].onBeforeCreate), + onBeforeCreate, + ]); + const handleAfterCreate = createCallbackChain([ + onAfterCreate, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterCreate), + ]); const build = async (inputData: Partial = {}) => { const seq = getSeq(); const requiredScalarData = autoGenerateUserScalarsOrEnums({ seq }); @@ -110,6 +130,7 @@ function defineUserFactoryInternal({ } : defaultData.profile }; const data: Prisma.UserCreateInput = { ...requiredScalarData, ...defaultData, ...defaultAssociations, ...inputData }; + await handleAfterBuild(data); return data; }; const buildList = (inputData: number | readonly Partial[]) => Promise.all(normalizeList(inputData).map(data => build(data))); @@ -118,7 +139,10 @@ function defineUserFactoryInternal({ }); const create = async (inputData: Partial = {}) => { const data = await build(inputData).then(screen); - return await getClient().user.create({ data }); + await handleBeforeCreate(data); + const createdData = await getClient().user.create({ data }); + await handleAfterCreate(createdData); + return createdData; }; const createList = (inputData: number | readonly Partial[]) => Promise.all(normalizeList(inputData).map(data => create(data))); const createForConnect = (inputData: Partial = {}) => create(inputData).then(pickForConnect); @@ -167,14 +191,16 @@ type ProfileFactoryDefineInput = { user: ProfileuserFactory | Prisma.UserCreateNestedOneWithoutProfileInput; }; +type ProfileFactoryTrait = { + data?: Resolver, BuildDataOptions>; +} & CallbackDefineOptions; + type ProfileFactoryDefineOptions = { defaultData: Resolver; traits?: { - [traitName: string | symbol]: { - data: Resolver, BuildDataOptions>; - }; + [traitName: string | symbol]: ProfileFactoryTrait; }; -}; +} & CallbackDefineOptions; function isProfileuserFactory(x: ProfileuserFactory | Prisma.UserCreateNestedOneWithoutProfileInput | undefined): x is ProfileuserFactory { return (x as any)?._factoryFor === "User"; @@ -205,11 +231,23 @@ function autoGenerateProfileScalarsOrEnums({ seq }: { }; } -function defineProfileFactoryInternal({ defaultData: defaultDataResolver, traits: traitsDefs = {} }: TOptions): ProfileFactoryInterface { +function defineProfileFactoryInternal({ defaultData: defaultDataResolver, onAfterBuild, onBeforeCreate, onAfterCreate, traits: traitsDefs = {} }: TOptions): ProfileFactoryInterface { const getFactoryWithTraits = (traitKeys: readonly ProfileTraitKeys[] = []) => { const seqKey = {}; const getSeq = () => getSequenceCounter(seqKey); const screen = createScreener("Profile", modelFieldDefinitions); + const handleAfterBuild = createCallbackChain([ + onAfterBuild, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterBuild), + ]); + const handleBeforeCreate = createCallbackChain([ + ...traitKeys.slice().reverse().map(traitKey => traitsDefs[traitKey].onBeforeCreate), + onBeforeCreate, + ]); + const handleAfterCreate = createCallbackChain([ + onAfterCreate, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterCreate), + ]); const build = async (inputData: Partial = {}) => { const seq = getSeq(); const requiredScalarData = autoGenerateProfileScalarsOrEnums({ seq }); @@ -229,6 +267,7 @@ function defineProfileFactoryInternal[]) => Promise.all(normalizeList(inputData).map(data => build(data))); @@ -237,7 +276,10 @@ function defineProfileFactoryInternal = {}) => { const data = await build(inputData).then(screen); - return await getClient().profile.create({ data }); + await handleBeforeCreate(data); + const createdData = await getClient().profile.create({ data }); + await handleAfterCreate(createdData); + return createdData; }; const createList = (inputData: number | readonly Partial[]) => Promise.all(normalizeList(inputData).map(data => create(data))); const createForConnect = (inputData: Partial = {}) => create(inputData).then(pickForConnect); diff --git a/packages/artifact-testing/fixtures/simple-model/__generated__/fabbrica/index.ts b/packages/artifact-testing/fixtures/simple-model/__generated__/fabbrica/index.ts index 9e08528..de574e1 100644 --- a/packages/artifact-testing/fixtures/simple-model/__generated__/fabbrica/index.ts +++ b/packages/artifact-testing/fixtures/simple-model/__generated__/fabbrica/index.ts @@ -1,13 +1,19 @@ import type { User } from "../client"; import { Prisma } from "../client"; import type { PrismaClient } from "../client"; -import { createInitializer, ModelWithFields, createScreener, getScalarFieldValueGenerator, Resolver, normalizeResolver, normalizeList, getSequenceCounter, } from "@quramy/prisma-fabbrica/lib/internal"; +import { createInitializer, ModelWithFields, createScreener, getScalarFieldValueGenerator, Resolver, normalizeResolver, normalizeList, getSequenceCounter, createCallbackChain, } from "@quramy/prisma-fabbrica/lib/internal"; export { resetSequence, registerScalarFieldValueGenerator, resetScalarFieldValueGenerator } from "@quramy/prisma-fabbrica/lib/internal"; type BuildDataOptions = { readonly seq: number; }; +type CallbackDefineOptions = { + onAfterBuild?: (createInput: TCreateInput) => void | PromiseLike; + onBeforeCreate?: (createInput: TCreateInput) => void | PromiseLike; + onAfterCreate?: (created: TCreated) => void | PromiseLike; +}; + const initializer = createInitializer(); const { getClient } = initializer; @@ -29,14 +35,16 @@ type UserFactoryDefineInput = { name?: string; }; +type UserFactoryTrait = { + data?: Resolver, BuildDataOptions>; +} & CallbackDefineOptions; + type UserFactoryDefineOptions = { defaultData?: Resolver; traits?: { - [traitName: string | symbol]: { - data: Resolver, BuildDataOptions>; - }; + [traitName: string | symbol]: UserFactoryTrait; }; -}; +} & CallbackDefineOptions; type UserTraitKeys = keyof TOptions["traits"]; @@ -64,11 +72,23 @@ function autoGenerateUserScalarsOrEnums({ seq }: { }; } -function defineUserFactoryInternal({ defaultData: defaultDataResolver, traits: traitsDefs = {} }: TOptions): UserFactoryInterface { +function defineUserFactoryInternal({ defaultData: defaultDataResolver, onAfterBuild, onBeforeCreate, onAfterCreate, traits: traitsDefs = {} }: TOptions): UserFactoryInterface { const getFactoryWithTraits = (traitKeys: readonly UserTraitKeys[] = []) => { const seqKey = {}; const getSeq = () => getSequenceCounter(seqKey); const screen = createScreener("User", modelFieldDefinitions); + const handleAfterBuild = createCallbackChain([ + onAfterBuild, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterBuild), + ]); + const handleBeforeCreate = createCallbackChain([ + ...traitKeys.slice().reverse().map(traitKey => traitsDefs[traitKey].onBeforeCreate), + onBeforeCreate, + ]); + const handleAfterCreate = createCallbackChain([ + onAfterCreate, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterCreate), + ]); const build = async (inputData: Partial = {}) => { const seq = getSeq(); const requiredScalarData = autoGenerateUserScalarsOrEnums({ seq }); @@ -84,6 +104,7 @@ function defineUserFactoryInternal({ }, resolveValue({ seq })); const defaultAssociations = {}; const data: Prisma.UserCreateInput = { ...requiredScalarData, ...defaultData, ...defaultAssociations, ...inputData }; + await handleAfterBuild(data); return data; }; const buildList = (inputData: number | readonly Partial[]) => Promise.all(normalizeList(inputData).map(data => build(data))); @@ -92,7 +113,10 @@ function defineUserFactoryInternal({ }); const create = async (inputData: Partial = {}) => { const data = await build(inputData).then(screen); - return await getClient().user.create({ data }); + await handleBeforeCreate(data); + const createdData = await getClient().user.create({ data }); + await handleAfterCreate(createdData); + return createdData; }; const createList = (inputData: number | readonly Partial[]) => Promise.all(normalizeList(inputData).map(data => create(data))); const createForConnect = (inputData: Partial = {}) => create(inputData).then(pickForConnect); diff --git a/packages/artifact-testing/fixtures/simple-model/callbacks.test.ts b/packages/artifact-testing/fixtures/simple-model/callbacks.test.ts new file mode 100644 index 0000000..5f8d517 --- /dev/null +++ b/packages/artifact-testing/fixtures/simple-model/callbacks.test.ts @@ -0,0 +1,107 @@ +import type { PrismaClient, User } from "./__generated__/client"; + +import { initialize, defineUserFactory } from "./__generated__/fabbrica"; + +describe("Generated functions", () => { + beforeAll(() => { + const clientStub = { + user: { create: jest.fn().mockReturnValue({ id: "stub id", name: "stub name" } as User) }, + } as unknown as PrismaClient; + initialize({ prisma: clientStub }); + }); + + describe("callback", () => { + test("onAfterBuild", async () => { + const mock = jest.fn(); + const UserFactory = defineUserFactory({ + onAfterBuild: user => { + mock(user); + }, + }); + + await UserFactory.build({ id: "id", name: "name" }); + + expect(mock).toBeCalledTimes(1); + expect(mock).toBeCalledWith({ id: "id", name: "name" }); + }); + + test("onBeforeCreate", async () => { + const mock = jest.fn(); + const UserFactory = defineUserFactory({ + onBeforeCreate: user => { + mock(user); + }, + }); + + await UserFactory.create({ id: "id", name: "name" }); + + expect(mock).toBeCalledTimes(1); + expect(mock).toBeCalledWith({ id: "id", name: "name" }); + }); + + test("onAfterCreate", async () => { + const mock = jest.fn(); + const UserFactory = defineUserFactory({ + onAfterCreate: user => { + mock(user); + }, + }); + + await UserFactory.create(); + + expect(mock).toBeCalledTimes(1); + expect(mock).toBeCalledWith({ id: "stub id", name: "stub name" }); + }); + + test("callback orders with traits", async () => { + const mock = jest.fn(); + const UserFactory = defineUserFactory({ + onAfterBuild: () => { + mock("factory default", "onAfterBuild"); + }, + onBeforeCreate: () => { + mock("factory default", "onBeforeCreate"); + }, + onAfterCreate: () => { + mock("factory default", "onAfterCreate"); + }, + traits: { + a: { + onAfterBuild: () => { + mock("trait a", "onAfterBuild"); + }, + onBeforeCreate: () => { + mock("trait a", "onBeforeCreate"); + }, + onAfterCreate: () => { + mock("trait a", "onAfterCreate"); + }, + }, + b: { + onAfterBuild: () => { + mock("trait b", "onAfterBuild"); + }, + onBeforeCreate: () => { + mock("trait b", "onBeforeCreate"); + }, + onAfterCreate: () => { + mock("trait b", "onAfterCreate"); + }, + }, + }, + }); + + await UserFactory.use("a", "b").create(); + + expect(mock).toHaveBeenNthCalledWith(1, "factory default", "onAfterBuild"); + expect(mock).toHaveBeenNthCalledWith(2, "trait a", "onAfterBuild"); + expect(mock).toHaveBeenNthCalledWith(3, "trait b", "onAfterBuild"); + expect(mock).toHaveBeenNthCalledWith(4, "trait b", "onBeforeCreate"); + expect(mock).toHaveBeenNthCalledWith(5, "trait a", "onBeforeCreate"); + expect(mock).toHaveBeenNthCalledWith(6, "factory default", "onBeforeCreate"); + expect(mock).toHaveBeenNthCalledWith(7, "factory default", "onAfterCreate"); + expect(mock).toHaveBeenNthCalledWith(8, "trait a", "onAfterCreate"); + expect(mock).toHaveBeenNthCalledWith(9, "trait b", "onAfterCreate"); + }); + }); +}); diff --git a/packages/prisma-fabbrica/src/helpers/callback.ts b/packages/prisma-fabbrica/src/helpers/callback.ts new file mode 100644 index 0000000..e4797f6 --- /dev/null +++ b/packages/prisma-fabbrica/src/helpers/callback.ts @@ -0,0 +1,11 @@ +export type CallbackFn = (...args: T) => unknown; + +export function createCallbackChain(callbackFns: readonly (CallbackFn | undefined)[]) { + return async (...args: T) => { + await callbackFns.reduce(async (acc, fn) => { + await acc; + if (!fn) return; + await fn(...args); + }, Promise.resolve()); + }; +} diff --git a/packages/prisma-fabbrica/src/helpers/index.ts b/packages/prisma-fabbrica/src/helpers/index.ts index 21004d2..6ac7f30 100644 --- a/packages/prisma-fabbrica/src/helpers/index.ts +++ b/packages/prisma-fabbrica/src/helpers/index.ts @@ -3,3 +3,4 @@ export * from "./stringConverter"; export * from "./sequence"; export * from "./selectors"; export * from "./list"; +export * from "./callback"; diff --git a/packages/prisma-fabbrica/src/internal.ts b/packages/prisma-fabbrica/src/internal.ts index 88212b0..425f311 100644 --- a/packages/prisma-fabbrica/src/internal.ts +++ b/packages/prisma-fabbrica/src/internal.ts @@ -1,6 +1,6 @@ export { getClient } from "./clientHolder"; export { ModelWithFields, createScreener } from "./relations/screen"; -export { Resolver, normalizeResolver, normalizeList, getSequenceCounter } from "./helpers"; +export { Resolver, normalizeResolver, normalizeList, getSequenceCounter, createCallbackChain } from "./helpers"; export { createInitializer, initialize, resetSequence } from "./initialize"; export { getScalarFieldValueGenerator, diff --git a/packages/prisma-fabbrica/src/templates/__snapshots__/getSourceFile.test.ts.snap b/packages/prisma-fabbrica/src/templates/__snapshots__/getSourceFile.test.ts.snap index 295f8b8..1a4d830 100644 --- a/packages/prisma-fabbrica/src/templates/__snapshots__/getSourceFile.test.ts.snap +++ b/packages/prisma-fabbrica/src/templates/__snapshots__/getSourceFile.test.ts.snap @@ -4,13 +4,19 @@ exports[`getSourceFile generates TypeScript AST 1`] = ` "import type { User } from "@prisma/client"; import { Prisma } from "@prisma/client"; import type { PrismaClient } from "@prisma/client"; -import { createInitializer, ModelWithFields, createScreener, getScalarFieldValueGenerator, Resolver, normalizeResolver, normalizeList, getSequenceCounter, } from "@quramy/prisma-fabbrica/lib/internal"; +import { createInitializer, ModelWithFields, createScreener, getScalarFieldValueGenerator, Resolver, normalizeResolver, normalizeList, getSequenceCounter, createCallbackChain, } from "@quramy/prisma-fabbrica/lib/internal"; export { resetSequence, registerScalarFieldValueGenerator, resetScalarFieldValueGenerator } from "@quramy/prisma-fabbrica/lib/internal"; type BuildDataOptions = { readonly seq: number; }; +type CallbackDefineOptions = { + onAfterBuild?: (createInput: TCreateInput) => void | PromiseLike; + onBeforeCreate?: (createInput: TCreateInput) => void | PromiseLike; + onAfterCreate?: (created: TCreated) => void | PromiseLike; +}; + const initializer = createInitializer(); const { getClient } = initializer; @@ -32,14 +38,16 @@ type UserFactoryDefineInput = { name?: string; }; +type UserFactoryTrait = { + data?: Resolver, BuildDataOptions>; +} & CallbackDefineOptions; + type UserFactoryDefineOptions = { defaultData?: Resolver; traits?: { - [traitName: string | symbol]: { - data: Resolver, BuildDataOptions>; - }; + [traitName: string | symbol]: UserFactoryTrait; }; -}; +} & CallbackDefineOptions; type UserTraitKeys = keyof TOptions["traits"]; @@ -67,11 +75,23 @@ function autoGenerateUserScalarsOrEnums({ seq }: { }; } -function defineUserFactoryInternal({ defaultData: defaultDataResolver, traits: traitsDefs = {} }: TOptions): UserFactoryInterface { +function defineUserFactoryInternal({ defaultData: defaultDataResolver, onAfterBuild, onBeforeCreate, onAfterCreate, traits: traitsDefs = {} }: TOptions): UserFactoryInterface { const getFactoryWithTraits = (traitKeys: readonly UserTraitKeys[] = []) => { const seqKey = {}; const getSeq = () => getSequenceCounter(seqKey); const screen = createScreener("User", modelFieldDefinitions); + const handleAfterBuild = createCallbackChain([ + onAfterBuild, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterBuild), + ]); + const handleBeforeCreate = createCallbackChain([ + ...traitKeys.slice().reverse().map(traitKey => traitsDefs[traitKey].onBeforeCreate), + onBeforeCreate, + ]); + const handleAfterCreate = createCallbackChain([ + onAfterCreate, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterCreate), + ]); const build = async (inputData: Partial = {}) => { const seq = getSeq(); const requiredScalarData = autoGenerateUserScalarsOrEnums({ seq }); @@ -87,6 +107,7 @@ function defineUserFactoryInternal({ }, resolveValue({ seq })); const defaultAssociations = {}; const data: Prisma.UserCreateInput = { ...requiredScalarData, ...defaultData, ...defaultAssociations, ...inputData }; + await handleAfterBuild(data); return data; }; const buildList = (inputData: number | readonly Partial[]) => Promise.all(normalizeList(inputData).map(data => build(data))); @@ -95,7 +116,10 @@ function defineUserFactoryInternal({ }); const create = async (inputData: Partial = {}) => { const data = await build(inputData).then(screen); - return await getClient().user.create({ data }); + await handleBeforeCreate(data); + const createdData = await getClient().user.create({ data }); + await handleAfterCreate(createdData); + return createdData; }; const createList = (inputData: number | readonly Partial[]) => Promise.all(normalizeList(inputData).map(data => create(data))); const createForConnect = (inputData: Partial = {}) => create(inputData).then(pickForConnect); diff --git a/packages/prisma-fabbrica/src/templates/index.ts b/packages/prisma-fabbrica/src/templates/index.ts index a48de6d..f81ad8c 100644 --- a/packages/prisma-fabbrica/src/templates/index.ts +++ b/packages/prisma-fabbrica/src/templates/index.ts @@ -89,24 +89,28 @@ export const header = (prismaClientModuleSpecifier: string) => normalizeResolver, normalizeList, getSequenceCounter, + createCallbackChain, } from "@quramy/prisma-fabbrica/lib/internal"; export { resetSequence, registerScalarFieldValueGenerator, resetScalarFieldValueGenerator } from "@quramy/prisma-fabbrica/lib/internal"; `(); -export const buildDataOptions = () => - template.statement` - type BuildDataOptions = { - readonly seq: number; - }; - `(); - export const importStatement = (specifier: string, prismaClientModuleSpecifier: string) => template.statement` import type { ${() => ast.identifier(specifier)} } from ${() => ast.stringLiteral(prismaClientModuleSpecifier)}; `(); -export const initializer = () => +export const genericDeclarations = () => template.sourceFile` + type BuildDataOptions = { + readonly seq: number; + }; + + type CallbackDefineOptions = { + onAfterBuild?: (createInput: TCreateInput) => void | PromiseLike; + onBeforeCreate?: (createInput: TCreateInput) => void | PromiseLike; + onAfterCreate?: (created: TCreated) => void | PromiseLike; + }; + const initializer = createInitializer(); const { getClient } = initializer; export const { initialize } = initializer; @@ -226,31 +230,42 @@ export const modelFactoryDefineInput = (model: DMMF.Model, inputType: DMMF.Input MODEL_FACTORY_DEFINE_INPUT: ast.identifier(`${model.name}FactoryDefineInput`), }); -export const modelFactoryDefineOptions = (modelName: string, isOpionalDefaultData: boolean) => { +export const modelFactoryTrait = (model: DMMF.Model) => + template.statement` + type MODEL_FACTORY_TRAIT = { + data?: Resolver, BuildDataOptions>; + } & CallbackDefineOptions; + `({ + MODEL_TYPE: ast.identifier(model.name), + MODEL_CREATE_INPUT: ast.identifier(`${model.name}CreateInput`), + MODEL_FACTORY_DEFINE_INPUT: ast.identifier(`${model.name}FactoryDefineInput`), + MODEL_FACTORY_TRAIT: ast.identifier(`${model.name}FactoryTrait`), + }); + +export const modelFactoryDefineOptions = (model: DMMF.Model, isOpionalDefaultData: boolean) => { const compiled = isOpionalDefaultData ? template.statement` type MODEL_FACTORY_DEFINE_OPTIONS = { defaultData?: Resolver; traits?: { - [traitName: string | symbol]: { - data: Resolver, BuildDataOptions>; - } + [traitName: string | symbol]: MODEL_FACTORY_TRAIT; }; - }; + } & CallbackDefineOptions; ` : template.statement` type MODEL_FACTORY_DEFINE_OPTIONS = { defaultData: Resolver; traits?: { - [traitName: string | symbol]: { - data: Resolver, BuildDataOptions>; - } + [traitName: string | symbol]: MODEL_FACTORY_TRAIT; }; - }; + } & CallbackDefineOptions; `; return compiled({ - MODEL_FACTORY_DEFINE_OPTIONS: ast.identifier(`${modelName}FactoryDefineOptions`), - MODEL_FACTORY_DEFINE_INPUT: ast.identifier(`${modelName}FactoryDefineInput`), + MODEL_TYPE: ast.identifier(model.name), + MODEL_CREATE_INPUT: ast.identifier(`${model.name}CreateInput`), + MODEL_FACTORY_DEFINE_INPUT: ast.identifier(`${model.name}FactoryDefineInput`), + MODEL_FACTORY_TRAIT: ast.identifier(`${model.name}FactoryTrait`), + MODEL_FACTORY_DEFINE_OPTIONS: ast.identifier(`${model.name}FactoryDefineOptions`), }); }; @@ -357,6 +372,9 @@ export const defineModelFactoryInternal = (model: DMMF.Model, inputType: DMMF.In template.statement` function DEFINE_MODEL_FACTORY_INTERNAL({ defaultData: defaultDataResolver, + onAfterBuild, + onBeforeCreate, + onAfterCreate, traits: traitsDefs = {} }: TOptions): MODEL_FACTORY_INTERFACE { const getFactoryWithTraits = (traitKeys: readonly MODEL_TRAIT_KEYS[] = []) => { @@ -364,6 +382,22 @@ export const defineModelFactoryInternal = (model: DMMF.Model, inputType: DMMF.In const getSeq = () => getSequenceCounter(seqKey); const screen = createScreener(${() => ast.stringLiteral(model.name)}, modelFieldDefinitions); + const handleAfterBuild = createCallbackChain([ + onAfterBuild, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterBuild), + ]); + const handleBeforeCreate = createCallbackChain([ + ...traitKeys + .slice() + .reverse() + .map(traitKey => traitsDefs[traitKey].onBeforeCreate), + onBeforeCreate, + ]); + const handleAfterCreate = createCallbackChain([ + onAfterCreate, + ...traitKeys.map(traitKey => traitsDefs[traitKey].onAfterCreate), + ]); + const build = async ( inputData: Partial = {} ) => { @@ -397,6 +431,7 @@ export const defineModelFactoryInternal = (model: DMMF.Model, inputType: DMMF.In true, )}; const data: Prisma.MODEL_CREATE_INPUT = { ...requiredScalarData, ...defaultData, ...defaultAssociations, ...inputData}; + await handleAfterBuild(data); return data; }; @@ -418,7 +453,10 @@ export const defineModelFactoryInternal = (model: DMMF.Model, inputType: DMMF.In inputData: Partial = {} ) => { const data = await build(inputData).then(screen); - return await getClient().MODEL_KEY.create({ data }); + await handleBeforeCreate(data); + const createdData = await getClient().MODEL_KEY.create({ data }); + await handleAfterCreate(createdData); + return createdData; }; const createList = ( @@ -514,8 +552,7 @@ export function getSourceFile({ ...modelNames.map(modelName => importStatement(modelName, prismaClientModuleSpecifier)), ...modelEnums.map(enumName => importStatement(enumName, prismaClientModuleSpecifier)), ...header(prismaClientModuleSpecifier).statements, - insertLeadingBreakMarker(buildDataOptions()), - ...insertLeadingBreakMarker(initializer().statements), + ...insertLeadingBreakMarker(genericDeclarations().statements), insertLeadingBreakMarker(modelFieldDefinitions(document.datamodel.models)), ...document.datamodel.models .reduce( @@ -532,7 +569,8 @@ export function getSourceFile({ modelBelongsToRelationFactory(fieldType, model), ), modelFactoryDefineInput(model, createInputType), - modelFactoryDefineOptions(model.name, filterRequiredInputObjectTypeField(createInputType).length === 0), + modelFactoryTrait(model), + modelFactoryDefineOptions(model, filterRequiredInputObjectTypeField(createInputType).length === 0), ...filterBelongsToField(model, createInputType).map(fieldType => isModelAssociationFactory(fieldType, model)), modelTraitKeys(model), modelFactoryInterfaceWithoutTraits(model),