Skip to content

Commit

Permalink
Merge branch 'main' of github.com:drizzle-team/drizzle-orm into pgvector
Browse files Browse the repository at this point in the history
  • Loading branch information
AndriiSherman committed Apr 29, 2024
2 parents 9185563 + e0aaeb2 commit 68815d1
Show file tree
Hide file tree
Showing 49 changed files with 3,286 additions and 2,443 deletions.
1 change: 1 addition & 0 deletions .github/workflows/release-feature-branch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ jobs:
PLANETSCALE_CONNECTION_STRING: ${{ secrets.PLANETSCALE_CONNECTION_STRING }}
NEON_CONNECTION_STRING: ${{ secrets.NEON_CONNECTION_STRING }}
XATA_API_KEY: ${{ secrets.XATA_API_KEY }}
XATA_BRANCH: ${{ secrets.XATA_BRANCH }}
LIBSQL_URL: file:local.db
run: |
if [[ ${{ github.event_name }} != "push" && "${{ github.event.pull_request.head.repo.full_name }}" != "${{ github.repository }}" ]]; then
Expand Down
16 changes: 16 additions & 0 deletions changelogs/drizzle-orm/0.30.8.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
- 🎉 Added custom schema support to enums in Postgres (fixes #669 via #2048):

```ts
import { pgSchema } from 'drizzle-orm/pg-core';

const mySchema = pgSchema('mySchema');
const colors = mySchema.enum('colors', ['red', 'green', 'blue']);
```

- 🎉 Changed D1 `migrate()` function to use batch API (#2137)
- 🐛 Split `where` clause in Postgres `.onConflictDoUpdate` method into `setWhere` and `targetWhere` clauses, to support both `where` cases in `on conflict ...` clause (fixes #1628, #1302 via #2056)
- 🐛 Fixed query generation for `where` clause in Postgres `.onConflictDoNothing` method, as it was placed in a wrong spot (fixes #1628 via #2056)
- 🐛 Fixed multiple issues with AWS Data API driver (fixes #1931, #1932, #1934, #1936 via #2119)
- 🐛 Fix inserting and updating array values in AWS Data API (fixes #1912 via #1911)

Thanks @hugo082 and @livingforjesus!
3 changes: 3 additions & 0 deletions changelogs/drizzle-orm/0.30.9.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
- 🐛 Fixed migrator in AWS Data API
- Added `setWhere` and `targetWhere` fields to `.onConflictDoUpdate()` config in SQLite instead of single `where` field
- 🛠️ Added schema information to Drizzle instances via `db._.fullSchema`
3 changes: 2 additions & 1 deletion dprint.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
"**/drizzle2/**/meta",
"**/*snapshot.json",
"**/_journal.json",
"**/tsup.config*.mjs"
"**/tsup.config*.mjs",
"**/.sst"
],
"plugins": [
"https://plugins.dprint.dev/typescript-0.83.0.wasm",
Expand Down
4 changes: 2 additions & 2 deletions drizzle-orm/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "drizzle-orm",
"version": "0.30.7",
"version": "0.30.9",
"description": "Drizzle ORM package for SQL databases",
"type": "module",
"scripts": {
Expand Down Expand Up @@ -147,7 +147,7 @@
}
},
"devDependencies": {
"@aws-sdk/client-rds-data": "^3.344.0",
"@aws-sdk/client-rds-data": "^3.549.0",
"@cloudflare/workers-types": "^4.20230904.0",
"@electric-sql/pglite": "^0.1.1",
"@libsql/client": "^0.5.6",
Expand Down
13 changes: 13 additions & 0 deletions drizzle-orm/src/aws-data-api/common/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,19 @@ export function getValueFromDataApi(field: Field) {
if (field.arrayValue.stringValues !== undefined) {
return field.arrayValue.stringValues;
}
if (field.arrayValue.longValues !== undefined) {
return field.arrayValue.longValues;
}
if (field.arrayValue.doubleValues !== undefined) {
return field.arrayValue.doubleValues;
}
if (field.arrayValue.booleanValues !== undefined) {
return field.arrayValue.booleanValues;
}
if (field.arrayValue.arrayValues !== undefined) {
return field.arrayValue.arrayValues;
}

throw new Error('Unknown array type');
} else {
throw new Error('Unknown type');
Expand Down
62 changes: 57 additions & 5 deletions drizzle-orm/src/aws-data-api/pg/driver.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
import { entityKind } from '~/entity.ts';
import { entityKind, is } from '~/entity.ts';
import type { SQL, SQLWrapper } from '~/index.ts';
import { Param, sql, Table } from '~/index.ts';
import type { Logger } from '~/logger.ts';
import { DefaultLogger } from '~/logger.ts';
import { PgDatabase } from '~/pg-core/db.ts';
import { PgDialect } from '~/pg-core/dialect.ts';
import type { PgColumn, PgInsertConfig, PgTable, TableConfig } from '~/pg-core/index.ts';
import { PgArray } from '~/pg-core/index.ts';
import type { PgRaw } from '~/pg-core/query-builders/raw.ts';
import {
createTableRelationsHelpers,
extractTablesRelationalConfig,
type RelationalSchemaConfig,
type TablesRelationalConfig,
} from '~/relations.ts';
import type { DrizzleConfig } from '~/utils.ts';
import type { AwsDataApiClient, AwsDataApiPgQueryResultHKT } from './session.ts';
import type { DrizzleConfig, UpdateSet } from '~/utils.ts';
import type { AwsDataApiClient, AwsDataApiPgQueryResult, AwsDataApiPgQueryResultHKT } from './session.ts';
import { AwsDataApiSession } from './session.ts';

export interface PgDriverOptions {
Expand All @@ -28,16 +33,63 @@ export interface DrizzleAwsDataApiPgConfig<
secretArn: string;
}

export type AwsDataApiPgDatabase<
export class AwsDataApiPgDatabase<
TSchema extends Record<string, unknown> = Record<string, never>,
> = PgDatabase<AwsDataApiPgQueryResultHKT, TSchema>;
> extends PgDatabase<AwsDataApiPgQueryResultHKT, TSchema> {
static readonly [entityKind]: string = 'AwsDataApiPgDatabase';

override execute<
TRow extends Record<string, unknown> = Record<string, unknown>,
>(query: SQLWrapper): PgRaw<AwsDataApiPgQueryResult<TRow>> {
return super.execute(query);
}
}

export class AwsPgDialect extends PgDialect {
static readonly [entityKind]: string = 'AwsPgDialect';

override escapeParam(num: number): string {
return `:${num + 1}`;
}

override buildInsertQuery(
{ table, values, onConflict, returning }: PgInsertConfig<PgTable<TableConfig>>,
): SQL<unknown> {
const columns: Record<string, PgColumn> = table[Table.Symbol.Columns];
const colEntries: [string, PgColumn][] = Object.entries(columns);
for (const value of values) {
for (const [fieldName, col] of colEntries) {
const colValue = value[fieldName];
if (
is(colValue, Param) && colValue.value !== undefined && is(colValue.encoder, PgArray)
&& Array.isArray(colValue.value)
) {
value[fieldName] = sql`cast(${col.mapToDriverValue(colValue.value)} as ${
sql.raw(colValue.encoder.getSQLType())
})`;
}
}
}

return super.buildInsertQuery({ table, values, onConflict, returning });
}

override buildUpdateSet(table: PgTable<TableConfig>, set: UpdateSet): SQL<unknown> {
const columns: Record<string, PgColumn> = table[Table.Symbol.Columns];

for (const [colName, colValue] of Object.entries(set)) {
const currentColumn = columns[colName];
if (
currentColumn && is(colValue, Param) && colValue.value !== undefined && is(colValue.encoder, PgArray)
&& Array.isArray(colValue.value)
) {
set[colName] = sql`cast(${currentColumn?.mapToDriverValue(colValue.value)} as ${
sql.raw(colValue.encoder.getSQLType())
})`;
}
}
return super.buildUpdateSet(table, set);
}
}

export function drizzle<TSchema extends Record<string, unknown> = Record<string, never>>(
Expand Down
65 changes: 43 additions & 22 deletions drizzle-orm/src/aws-data-api/pg/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,18 @@ export class AwsDataApiPreparedQuery<T extends PreparedQueryConfig> extends PgPr
async execute(placeholderValues: Record<string, unknown> | undefined = {}): Promise<T['execute']> {
const { fields, joinsNotNullableMap, customResultMapper } = this;

const rows = await this.values(placeholderValues) as unknown[][];
const result = await this.values(placeholderValues) as AwsDataApiPgQueryResult<unknown[]>;
if (!fields && !customResultMapper) {
return rows as T['execute'];
return result as T['execute'];
}
return customResultMapper
? customResultMapper(rows)
: rows.map((row) => mapResultRow<T['execute']>(fields!, row, joinsNotNullableMap));
? customResultMapper(result.rows!)
: result.rows!.map((row) => mapResultRow<T['execute']>(fields!, row, joinsNotNullableMap));
}

all(placeholderValues?: Record<string, unknown> | undefined): Promise<T['all']> {
return this.execute(placeholderValues);
async all(placeholderValues?: Record<string, unknown> | undefined): Promise<T['all']> {
const result = await this.execute(placeholderValues) as AwsDataApiPgQueryResult<unknown>;
return result.rows;
}

async values(placeholderValues: Record<string, unknown> = {}): Promise<T['values']> {
Expand All @@ -83,16 +84,24 @@ export class AwsDataApiPreparedQuery<T extends PreparedQueryConfig> extends PgPr
if (!fields && !customResultMapper) {
const result = await client.send(rawQuery);
if (result.columnMetadata && result.columnMetadata.length > 0) {
return this.mapResultRows(result.records ?? [], result.columnMetadata);
const rows = this.mapResultRows(result.records ?? [], result.columnMetadata);
return {
...result,
rows,
};
}
return result.records ?? [];
return result;
}

const result = await client.send(rawQuery);
const rows = result.records?.map((row) => {
return row.map((field) => getValueFromDataApi(field));
}) ?? [];

return result.records?.map((row: any) => {
return row.map((field: Field) => getValueFromDataApi(field));
});
return {
...result,
rows,
};
}

/** @internal */
Expand Down Expand Up @@ -155,9 +164,10 @@ export class AwsDataApiSession<
prepareQuery<T extends PreparedQueryConfig = PreparedQueryConfig>(
query: QueryWithTypings,
fields: SelectedFieldsOrdered | undefined,
transactionId: string | undefined,
name: string | undefined,
isResponseInArrayMode: boolean,
customResultMapper?: (rows: unknown[][]) => T['execute'],
transactionId?: string,
): PgPreparedQuery<T> {
return new AwsDataApiPreparedQuery(
this.client,
Expand All @@ -166,7 +176,7 @@ export class AwsDataApiSession<
query.typings ?? [],
this.options,
fields,
transactionId,
transactionId ?? this.transactionId,
isResponseInArrayMode,
customResultMapper,
);
Expand All @@ -176,8 +186,10 @@ export class AwsDataApiSession<
return this.prepareQuery<PreparedQueryConfig & { execute: T }>(
this.dialect.sqlToQuery(query),
undefined,
this.transactionId,
undefined,
false,
undefined,
this.transactionId,
).execute();
}

Expand All @@ -187,7 +199,7 @@ export class AwsDataApiSession<
): Promise<T> {
const { transactionId } = await this.client.send(new BeginTransactionCommand(this.rawQuery));
const session = new AwsDataApiSession(this.client, this.dialect, this.schema, this.options, transactionId);
const tx = new AwsDataApiTransaction(this.dialect, session, this.schema);
const tx = new AwsDataApiTransaction<TFullSchema, TSchema>(this.dialect, session, this.schema);
if (config) {
await tx.setTransaction(config);
}
Expand All @@ -208,21 +220,30 @@ export class AwsDataApiTransaction<
> extends PgTransaction<AwsDataApiPgQueryResultHKT, TFullSchema, TSchema> {
static readonly [entityKind]: string = 'AwsDataApiTransaction';

override transaction<T>(transaction: (tx: AwsDataApiTransaction<TFullSchema, TSchema>) => Promise<T>): Promise<T> {
override async transaction<T>(
transaction: (tx: AwsDataApiTransaction<TFullSchema, TSchema>) => Promise<T>,
): Promise<T> {
const savepointName = `sp${this.nestedIndex + 1}`;
const tx = new AwsDataApiTransaction(this.dialect, this.session, this.schema, this.nestedIndex + 1);
this.session.execute(sql`savepoint ${savepointName}`);
const tx = new AwsDataApiTransaction<TFullSchema, TSchema>(
this.dialect,
this.session,
this.schema,
this.nestedIndex + 1,
);
await this.session.execute(sql.raw(`savepoint ${savepointName}`));
try {
const result = transaction(tx);
this.session.execute(sql`release savepoint ${savepointName}`);
const result = await transaction(tx);
await this.session.execute(sql.raw(`release savepoint ${savepointName}`));
return result;
} catch (e) {
this.session.execute(sql`rollback to savepoint ${savepointName}`);
await this.session.execute(sql.raw(`rollback to savepoint ${savepointName}`));
throw e;
}
}
}

export type AwsDataApiPgQueryResult<T> = ExecuteStatementCommandOutput & { rows: T[] };

export interface AwsDataApiPgQueryResultHKT extends QueryResultHKT {
type: ExecuteStatementCommandOutput;
type: AwsDataApiPgQueryResult<any>;
}
2 changes: 1 addition & 1 deletion drizzle-orm/src/d1/driver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import { SQLiteD1Session } from './session.ts';
export class DrizzleD1Database<
TSchema extends Record<string, unknown> = Record<string, never>,
> extends BaseSQLiteDatabase<'async', D1Result, TSchema> {
static readonly [entityKind]: string = 'LibSQLDatabase';
static readonly [entityKind]: string = 'D1Database';

/** @internal */
declare readonly session: SQLiteD1Session<TSchema, ExtractTablesWithRelations<TSchema>>;
Expand Down
44 changes: 43 additions & 1 deletion drizzle-orm/src/d1/migrator.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,53 @@
import type { MigrationConfig } from '~/migrator.ts';
import { readMigrationFiles } from '~/migrator.ts';
import { sql } from '~/sql/sql.ts';
import type { DrizzleD1Database } from './driver.ts';

export async function migrate<TSchema extends Record<string, unknown>>(
db: DrizzleD1Database<TSchema>,
config: string | MigrationConfig,
) {
const migrations = readMigrationFiles(config);
await db.dialect.migrate(migrations, db.session, config);
const migrationsTable = config === undefined
? '__drizzle_migrations'
: typeof config === 'string'
? '__drizzle_migrations'
: config.migrationsTable ?? '__drizzle_migrations';

const migrationTableCreate = sql`
CREATE TABLE IF NOT EXISTS ${sql.identifier(migrationsTable)} (
id SERIAL PRIMARY KEY,
hash text NOT NULL,
created_at numeric
)
`;
await db.session.run(migrationTableCreate);

const dbMigrations = await db.values<[number, string, string]>(
sql`SELECT id, hash, created_at FROM ${sql.identifier(migrationsTable)} ORDER BY created_at DESC LIMIT 1`,
);

const lastDbMigration = dbMigrations[0] ?? undefined;

const statementToBatch = [];

for (const migration of migrations) {
if (!lastDbMigration || Number(lastDbMigration[2])! < migration.folderMillis) {
for (const stmt of migration.sql) {
statementToBatch.push(db.run(sql.raw(stmt)));
}

statementToBatch.push(
db.run(
sql`INSERT INTO ${sql.identifier(migrationsTable)} ("hash", "created_at") VALUES(${
sql.raw(`'${migration.hash}'`)
}, ${sql.raw(`${migration.folderMillis}`)})`,
),
);
}
}

if (statementToBatch.length > 0) {
await db.session.batch(statementToBatch);
}
}
2 changes: 1 addition & 1 deletion drizzle-orm/src/d1/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ export class SQLiteD1Session<
);
}

async batch<U extends BatchItem, T extends Readonly<[U, ...U[]]>>(queries: T) {
async batch<T extends BatchItem<'sqlite'>[] | readonly BatchItem<'sqlite'>[]>(queries: T) {
const preparedQueries: PreparedQuery[] = [];
const builtQueries: D1PreparedStatement[] = [];

Expand Down
10 changes: 8 additions & 2 deletions drizzle-orm/src/libsql/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,14 @@ export class LibSQLSession<
): Promise<T> {
// TODO: support transaction behavior
const libsqlTx = await this.client.transaction();
const session = new LibSQLSession(this.client, this.dialect, this.schema, this.options, libsqlTx);
const tx = new LibSQLTransaction('async', this.dialect, session, this.schema);
const session = new LibSQLSession<TFullSchema, TSchema>(
this.client,
this.dialect,
this.schema,
this.options,
libsqlTx,
);
const tx = new LibSQLTransaction<TFullSchema, TSchema>('async', this.dialect, session, this.schema);
try {
const result = await transaction(tx);
await libsqlTx.commit();
Expand Down

0 comments on commit 68815d1

Please sign in to comment.