mirror of
https://github.com/typeorm/typeorm.git
synced 2025-12-08 21:26:23 +00:00
feat(postgres): support vector/halfvec data types (#11437)
This commit is contained in:
parent
96ea431eb7
commit
a49f612289
8
.github/workflows/tests-linux.yml
vendored
8
.github/workflows/tests-linux.yml
vendored
@ -253,13 +253,11 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
postgis-version:
|
||||
- "14-3.5"
|
||||
- "17-3.5"
|
||||
postgres-version: ["14", "17"]
|
||||
|
||||
services:
|
||||
postgres:
|
||||
image: postgis/postgis:${{ matrix.postgis-version }}
|
||||
image: ghcr.io/naorpeled/typeorm-postgres:pg${{ matrix.postgres-version }}-postgis3-pgvectorv0.8.0
|
||||
ports:
|
||||
- "5432:5432"
|
||||
env:
|
||||
@ -288,7 +286,7 @@ jobs:
|
||||
- name: Coveralls Parallel
|
||||
uses: coverallsapp/github-action@v2
|
||||
with:
|
||||
flag-name: postgres:${{ matrix.postgis-version }}-node:${{ inputs.node-version }}
|
||||
flag-name: postgres:${{ matrix.postgres-version }}-node:${{ inputs.node-version }}
|
||||
parallel: true
|
||||
|
||||
oracle:
|
||||
|
||||
@ -49,7 +49,7 @@ services:
|
||||
postgres-14:
|
||||
# postgis is postgres + PostGIS (only). If you need additional extensions,
|
||||
# it's probably time to create a purpose-built image.
|
||||
image: "postgis/postgis:14-3.5"
|
||||
image: "ghcr.io/naorpeled/typeorm-postgres:pg14-postgis3-pgvectorv0.8.0"
|
||||
container_name: "typeorm-postgres-14"
|
||||
ports:
|
||||
- "5432:5432"
|
||||
@ -62,7 +62,7 @@ services:
|
||||
postgres-17:
|
||||
# postgis is postgres + PostGIS (only). If you need additional extensions,
|
||||
# it's probably time to create a purpose-built image.
|
||||
image: "postgis/postgis:17-3.5"
|
||||
image: "ghcr.io/naorpeled/typeorm-postgres:pg17-postgis3-pgvectorv0.8.0"
|
||||
container_name: "typeorm-postgres-17"
|
||||
ports:
|
||||
- "5432:5432"
|
||||
|
||||
@ -37,8 +37,36 @@ SAP HANA 2.0 and SAP HANA Cloud support slightly different data types. Check the
|
||||
- [SAP HANA 2.0 Data Types](https://help.sap.com/docs/SAP_HANA_PLATFORM/4fe29514fd584807ac9f2a04f6754767/20a1569875191014b507cf392724b7eb.html?locale=en-US)
|
||||
- [SAP HANA Cloud Data Types](https://help.sap.com/docs/hana-cloud-database/sap-hana-cloud-sap-hana-database-sql-reference-guide/data-types)
|
||||
|
||||
TypeORM's `SapDriver` supports `tinyint`, `smallint`, `integer`, `bigint`, `smalldecimal`, `decimal`, `real`, `double`, `date`, `time`, `seconddate`, `timestamp`, `boolean`, `char`, `nchar`, `varchar`, `nvarchar`, `text`, `alphanum`, `shorttext`, `array`, `varbinary`, `blob`, `clob`, `nclob`, `st_geometry`, `st_point`, `real_vector` and `half_vector`. Some of these data types have been deprecated or removed in SAP HANA Cloud, and will be converted to the closest available alternative when connected to a Cloud database.
|
||||
TypeORM's `SapDriver` supports `tinyint`, `smallint`, `integer`, `bigint`, `smalldecimal`, `decimal`, `real`, `double`, `date`, `time`, `seconddate`, `timestamp`, `boolean`, `char`, `nchar`, `varchar`, `nvarchar`, `text`, `alphanum`, `shorttext`, `array`, `varbinary`, `blob`, `clob`, `nclob`, `st_geometry`, `st_point`, `real_vector`, `half_vector`, `vector`, and `halfvec`. Some of these data types have been deprecated or removed in SAP HANA Cloud, and will be converted to the closest available alternative when connected to a Cloud database.
|
||||
|
||||
### Vector Types
|
||||
|
||||
The `real_vector` and `half_vector` data types were introduced in SAP HANA Cloud (2024Q1 and 2025Q2 respectively), and require a supported version of `@sap/hana-client` as well. By default, the client will return a `Buffer` in the `fvecs`/`hvecs` format, which is more efficient. It is possible to let the driver convert the values to a `number[]` by adding `{ extra: { vectorOutputType: "Array" } }` to the connection options. Check the SAP HANA Client documentation for more information about [REAL_VECTOR](https://help.sap.com/docs/SAP_HANA_CLIENT/f1b440ded6144a54ada97ff95dac7adf/0d197e4389c64e6b9cf90f6f698f62fe.html) or [HALF_VECTOR](https://help.sap.com/docs/SAP_HANA_CLIENT/f1b440ded6144a54ada97ff95dac7adf/8bb854b4ce4a4299bed27c365b717e91.html).
|
||||
The `real_vector` and `half_vector` data types were introduced in SAP HANA Cloud (2024Q1 and 2025Q2 respectively), and require a supported version of `@sap/hana-client` as well.
|
||||
|
||||
For consistency with PostgreSQL's vector support, TypeORM also provides aliases:
|
||||
- `vector` (alias for `real_vector`) - stores vectors as 4-byte floats
|
||||
- `halfvec` (alias for `half_vector`) - stores vectors as 2-byte floats for memory efficiency
|
||||
|
||||
```typescript
|
||||
@Entity()
|
||||
export class Document {
|
||||
@PrimaryGeneratedColumn()
|
||||
id: number
|
||||
|
||||
// Using SAP HANA native type names
|
||||
@Column("real_vector", { length: 1536 })
|
||||
embedding: Buffer | number[]
|
||||
|
||||
@Column("half_vector", { length: 768 })
|
||||
reduced_embedding: Buffer | number[]
|
||||
|
||||
// Using cross-database aliases (recommended)
|
||||
@Column("vector", { length: 1536 })
|
||||
universal_embedding: Buffer | number[]
|
||||
|
||||
@Column("halfvec", { length: 768 })
|
||||
universal_reduced_embedding: Buffer | number[]
|
||||
}
|
||||
```
|
||||
|
||||
By default, the client will return a `Buffer` in the `fvecs`/`hvecs` format, which is more efficient. It is possible to let the driver convert the values to a `number[]` by adding `{ extra: { vectorOutputType: "Array" } }` to the connection options. Check the SAP HANA Client documentation for more information about [REAL_VECTOR](https://help.sap.com/docs/SAP_HANA_CLIENT/f1b440ded6144a54ada97ff95dac7adf/0d197e4389c64e6b9cf90f6f698f62fe.html) or [HALF_VECTOR](https://help.sap.com/docs/SAP_HANA_CLIENT/f1b440ded6144a54ada97ff95dac7adf/8bb854b4ce4a4299bed27c365b717e91.html).
|
||||
|
||||
@ -180,6 +180,67 @@ There are several special column types with additional functionality available:
|
||||
each time you call `save` of entity manager or repository, or during `upsert` operations when an update occurs.
|
||||
You don't need to set this column - it will be automatically set.
|
||||
|
||||
### Vector columns
|
||||
|
||||
Vector columns are supported on both PostgreSQL (via [`pgvector`](https://github.com/pgvector/pgvector) extension) and SAP HANA Cloud, enabling storing and querying vector embeddings for similarity search and machine learning applications.
|
||||
|
||||
TypeORM supports both `vector` and `halfvec` column types across databases:
|
||||
|
||||
- `vector` - stores vectors as 4-byte floats (single precision)
|
||||
- PostgreSQL: native `vector` type via pgvector extension
|
||||
- SAP HANA: alias for `real_vector` type
|
||||
- `halfvec` - stores vectors as 2-byte floats (half precision) for memory efficiency
|
||||
- PostgreSQL: native `halfvec` type via pgvector extension
|
||||
- SAP HANA: alias for `half_vector` type
|
||||
|
||||
You can specify the vector dimensions using the `length` option:
|
||||
|
||||
```typescript
|
||||
@Entity()
|
||||
export class Post {
|
||||
@PrimaryGeneratedColumn()
|
||||
id: number
|
||||
|
||||
// Vector without specified dimensions (works on PostgreSQL and SAP HANA)
|
||||
@Column("vector")
|
||||
embedding: number[] | Buffer
|
||||
|
||||
// Vector with 3 dimensions: vector(3) (works on PostgreSQL and SAP HANA)
|
||||
@Column("vector", { length: 3 })
|
||||
embedding_3d: number[] | Buffer
|
||||
|
||||
// Half-precision vector with 4 dimensions: halfvec(4) (works on PostgreSQL and SAP HANA)
|
||||
@Column("halfvec", { length: 4 })
|
||||
halfvec_embedding: number[] | Buffer
|
||||
}
|
||||
```
|
||||
|
||||
Vector columns can be used for similarity searches using PostgreSQL's vector operators:
|
||||
|
||||
```typescript
|
||||
// L2 distance (Euclidean) - <->
|
||||
const results = await dataSource.query(
|
||||
`SELECT id, embedding FROM post ORDER BY embedding <-> $1 LIMIT 5`,
|
||||
["[1,2,3]"]
|
||||
)
|
||||
|
||||
// Cosine distance - <=>
|
||||
const results = await dataSource.query(
|
||||
`SELECT id, embedding FROM post ORDER BY embedding <=> $1 LIMIT 5`,
|
||||
["[1,2,3]"]
|
||||
)
|
||||
|
||||
// Inner product - <#>
|
||||
const results = await dataSource.query(
|
||||
`SELECT id, embedding FROM post ORDER BY embedding <#> $1 LIMIT 5`,
|
||||
["[1,2,3]"]
|
||||
)
|
||||
```
|
||||
|
||||
> **Note**:
|
||||
> - **PostgreSQL**: Vector columns require the `pgvector` extension to be installed. The extension provides the vector data types and similarity operators.
|
||||
> - **SAP HANA**: Vector columns require SAP HANA Cloud (2024Q1+) and a supported version of `@sap/hana-client`. Use the appropriate [vector similarity functions](https://help.sap.com/docs/hana-cloud-database/sap-hana-cloud-sap-hana-database-sql-reference-guide/vector-functions) for similarity searches.
|
||||
|
||||
## Column types
|
||||
|
||||
TypeORM supports all of the most commonly used database-supported column types.
|
||||
|
||||
@ -191,6 +191,8 @@ export class PostgresDriver implements Driver {
|
||||
"geography",
|
||||
"cube",
|
||||
"ltree",
|
||||
"vector",
|
||||
"halfvec",
|
||||
]
|
||||
|
||||
/**
|
||||
@ -214,6 +216,8 @@ export class PostgresDriver implements Driver {
|
||||
"bit",
|
||||
"varbit",
|
||||
"bit varying",
|
||||
"vector",
|
||||
"halfvec",
|
||||
]
|
||||
|
||||
/**
|
||||
@ -409,6 +413,7 @@ export class PostgresDriver implements Driver {
|
||||
hasCubeColumns,
|
||||
hasGeometryColumns,
|
||||
hasLtreeColumns,
|
||||
hasVectorColumns,
|
||||
hasExclusionConstraints,
|
||||
} = extensionsMetadata
|
||||
|
||||
@ -488,6 +493,18 @@ export class PostgresDriver implements Driver {
|
||||
"At least one of the entities has a ltree column, but the 'ltree' extension cannot be installed automatically. Please install it manually using superuser rights",
|
||||
)
|
||||
}
|
||||
if (hasVectorColumns)
|
||||
try {
|
||||
await this.executeQuery(
|
||||
connection,
|
||||
`CREATE EXTENSION IF NOT EXISTS "vector"`,
|
||||
)
|
||||
} catch (_) {
|
||||
logger.log(
|
||||
"warn",
|
||||
"At least one of the entities has a vector column, but the 'vector' extension (pgvector) cannot be installed automatically. Please install it manually using superuser rights",
|
||||
)
|
||||
}
|
||||
if (hasExclusionConstraints)
|
||||
try {
|
||||
// The btree_gist extension provides operator support in PostgreSQL exclusion constraints
|
||||
@ -556,6 +573,14 @@ export class PostgresDriver implements Driver {
|
||||
)
|
||||
},
|
||||
)
|
||||
const hasVectorColumns = this.connection.entityMetadatas.some(
|
||||
(metadata) => {
|
||||
return metadata.columns.some(
|
||||
(column) =>
|
||||
column.type === "vector" || column.type === "halfvec",
|
||||
)
|
||||
},
|
||||
)
|
||||
const hasExclusionConstraints = this.connection.entityMetadatas.some(
|
||||
(metadata) => {
|
||||
return metadata.exclusions.length > 0
|
||||
@ -569,6 +594,7 @@ export class PostgresDriver implements Driver {
|
||||
hasCubeColumns,
|
||||
hasGeometryColumns,
|
||||
hasLtreeColumns,
|
||||
hasVectorColumns,
|
||||
hasExclusionConstraints,
|
||||
hasExtensions:
|
||||
hasUuidColumns ||
|
||||
@ -577,6 +603,7 @@ export class PostgresDriver implements Driver {
|
||||
hasGeometryColumns ||
|
||||
hasCubeColumns ||
|
||||
hasLtreeColumns ||
|
||||
hasVectorColumns ||
|
||||
hasExclusionConstraints,
|
||||
}
|
||||
}
|
||||
@ -641,6 +668,15 @@ export class PostgresDriver implements Driver {
|
||||
) >= 0
|
||||
) {
|
||||
return JSON.stringify(value)
|
||||
} else if (
|
||||
columnMetadata.type === "vector" ||
|
||||
columnMetadata.type === "halfvec"
|
||||
) {
|
||||
if (Array.isArray(value)) {
|
||||
return `[${value.join(",")}]`
|
||||
} else {
|
||||
return value
|
||||
}
|
||||
} else if (columnMetadata.type === "hstore") {
|
||||
if (typeof value === "string") {
|
||||
return value
|
||||
@ -717,6 +753,18 @@ export class PostgresDriver implements Driver {
|
||||
value = DateUtils.mixedDateToDateString(value)
|
||||
} else if (columnMetadata.type === "time") {
|
||||
value = DateUtils.mixedTimeToString(value)
|
||||
} else if (
|
||||
columnMetadata.type === "vector" ||
|
||||
columnMetadata.type === "halfvec"
|
||||
) {
|
||||
if (
|
||||
typeof value === "string" &&
|
||||
value.startsWith("[") &&
|
||||
value.endsWith("]")
|
||||
) {
|
||||
if (value === "[]") return []
|
||||
return value.slice(1, -1).split(",").map(Number)
|
||||
}
|
||||
} else if (columnMetadata.type === "hstore") {
|
||||
if (columnMetadata.hstoreType === "object") {
|
||||
const unescapeString = (str: string) =>
|
||||
@ -1139,6 +1187,9 @@ export class PostgresDriver implements Driver {
|
||||
} else {
|
||||
type = column.type
|
||||
}
|
||||
} else if (column.type === "vector" || column.type === "halfvec") {
|
||||
type =
|
||||
column.type + (column.length ? "(" + column.length + ")" : "")
|
||||
}
|
||||
|
||||
if (column.isArray) type += " array"
|
||||
|
||||
@ -3491,6 +3491,18 @@ export class PostgresQueryRunner
|
||||
tableColumn.name = dbColumn["column_name"]
|
||||
tableColumn.type = dbColumn["regtype"].toLowerCase()
|
||||
|
||||
if (
|
||||
tableColumn.type === "vector" ||
|
||||
tableColumn.type === "halfvec"
|
||||
) {
|
||||
const lengthMatch = dbColumn[
|
||||
"format_type"
|
||||
].match(/^(?:vector|halfvec)\((\d+)\)$/)
|
||||
if (lengthMatch && lengthMatch[1]) {
|
||||
tableColumn.length = lengthMatch[1]
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
tableColumn.type === "numeric" ||
|
||||
tableColumn.type === "numeric[]" ||
|
||||
|
||||
@ -646,6 +646,10 @@ export class SapDriver implements Driver {
|
||||
return "nclob"
|
||||
} else if (column.type === "simple-enum") {
|
||||
return "nvarchar"
|
||||
} else if (column.type === "vector") {
|
||||
return "real_vector"
|
||||
} else if (column.type === "halfvec") {
|
||||
return "half_vector"
|
||||
}
|
||||
|
||||
if (DriverUtils.isReleaseVersionOrGreater(this, "4.0")) {
|
||||
|
||||
@ -75,6 +75,8 @@ export type WithLengthColumnType =
|
||||
| "binary" // mssql
|
||||
| "varbinary" // mssql, sap
|
||||
| "string" // cockroachdb, spanner
|
||||
| "vector" // postgres, sap
|
||||
| "halfvec" // postgres, sap
|
||||
| "half_vector" // sap
|
||||
| "real_vector" // sap
|
||||
|
||||
|
||||
@ -0,0 +1,21 @@
|
||||
import { Entity } from "../../../../../../src/decorator/entity/Entity"
|
||||
import { Column } from "../../../../../../src/decorator/columns/Column"
|
||||
import { PrimaryGeneratedColumn } from "../../../../../../src/decorator/columns/PrimaryGeneratedColumn"
|
||||
|
||||
/**
 * Test entity exercising the pgvector column types added for PostgreSQL:
 * `vector` (4-byte floats) and `halfvec` (2-byte floats), both with and
 * without a fixed dimension (`length`).
 */
@Entity()
export class Post {
    @PrimaryGeneratedColumn()
    id: number

    // Single-precision vector with no dimension constraint.
    @Column("vector", { nullable: true })
    embedding: number[]

    // Fixed-dimension vector; rendered as vector(3) in DDL, so the database
    // rejects values with any other dimensionality.
    @Column("vector", { length: 3, nullable: true })
    embedding_three_dimensions: number[]

    // Half-precision vector (pgvector "halfvec") with no dimension constraint.
    @Column("halfvec", { nullable: true })
    halfvec_embedding: number[]

    // Fixed-dimension half-precision vector; rendered as halfvec(4) in DDL.
    @Column("halfvec", { length: 4, nullable: true })
    halfvec_four_dimensions: number[]
}
|
||||
@ -0,0 +1,203 @@
|
||||
import "reflect-metadata"
|
||||
import { expect } from "chai"
|
||||
import { DataSource } from "../../../../../src/data-source/DataSource"
|
||||
import {
|
||||
closeTestingConnections,
|
||||
createTestingConnections,
|
||||
reloadTestingDatabases,
|
||||
} from "../../../../utils/test-utils"
|
||||
import { Post } from "./entity/Post"
|
||||
|
||||
// Verifies pgvector similarity operators (<->, <=>, <#>) against a live
// PostgreSQL connection, plus dimension-constraint enforcement on save.
describe("columns > vector type > similarity operations", () => {
    let connections: DataSource[]
    before(async () => {
        connections = await createTestingConnections({
            entities: [Post],
            enabledDrivers: ["postgres"],
            schemaCreate: true,
            dropSchema: true,
        })
    })

    beforeEach(() => reloadTestingDatabases(connections))
    after(() => closeTestingConnections(connections))

    // Seeds a fixed set of 3-dimensional vectors with known pairwise
    // distances so each operator's ordering can be asserted exactly.
    async function setupTestData(connection: DataSource) {
        const postRepository = connection.getRepository(Post)
        await postRepository.clear() // Clear existing data

        // Create test posts with known vectors
        const posts = await postRepository.save([
            { embedding: [1, 1, 1] },
            { embedding: [1, 1, 2] },
            { embedding: [5, 5, 5] },
            { embedding: [2, 2, 2] },
            { embedding: [-1, -1, -1] },
        ])

        return posts
    }

    it("should perform similarity search using L2 distance", () =>
        Promise.all(
            connections.map(async (connection) => {
                await setupTestData(connection)
                const queryVector = "[1,1,1.6]" // Search vector

                const results = await connection.query(
                    `SELECT id, embedding FROM "post" ORDER BY embedding <-> $1 LIMIT 2`,
                    [queryVector],
                )

                expect(results.length).to.equal(2)
                // [1,1,2] should be closest to [1,1,1.6], then [1,1,1]
                expect(results[0].embedding).to.deep.equal("[1,1,2]")
                expect(results[1].embedding).to.deep.equal("[1,1,1]")
            }),
        ))

    it("should perform similarity search using cosine distance", () =>
        Promise.all(
            connections.map(async (connection) => {
                await setupTestData(connection)
                const queryVector = "[1,1,1]" // Search vector

                const results = await connection.query(
                    `SELECT id, embedding FROM "post" ORDER BY embedding <=> $1 LIMIT 3`,
                    [queryVector],
                )

                expect(results.length).to.equal(3)
                // [1,1,1] and [2,2,2] should have cosine distance 0 (same direction)
                // [-1,-1,-1] should be last (opposite direction)
                const embeddings = results.map(
                    (r: { embedding: string }) => r.embedding, // Ensure type is string for raw results
                )
                expect(embeddings).to.deep.include.members([
                    "[1,1,1]",
                    "[2,2,2]",
                ])
                expect(embeddings).to.not.deep.include("[-1,-1,-1]")
            }),
        ))

    it("should perform similarity search using inner product", () =>
        Promise.all(
            connections.map(async (connection) => {
                const postRepository = connection.getRepository(Post)
                await postRepository.clear()

                // Create vectors with known inner products
                await postRepository.save([
                    { embedding: [1, 2, 3] }, // IP with [1,1,1] = 6
                    { embedding: [3, 3, 3] }, // IP with [1,1,1] = 9
                    { embedding: [-1, 0, 1] }, // IP with [1,1,1] = 0
                ])

                const queryVector = "[1,1,1]" // Search vector

                const results = await connection.query(
                    `SELECT id, embedding FROM "post" ORDER BY embedding <#> $1 ASC LIMIT 2`, // The <#> operator returns negative inner product, so ASC ordering gives highest positive inner product first (most similar vectors)
                    [queryVector],
                )

                expect(results.length).to.equal(2)
                // [3,3,3] should have highest inner product, then [1,2,3]
                expect(results[0].embedding).to.deep.equal("[3,3,3]")
                expect(results[1].embedding).to.deep.equal("[1,2,3]")
            }),
        ))

    it("should prevent persistence of Post with incorrect vector dimensions due to DB constraints", () =>
        Promise.all(
            connections.map(async (connection) => {
                const postRepository = connection.getRepository(Post)
                const post = new Post()
                post.embedding_three_dimensions = [1, 1] // Wrong dimensions (2 instead of 3)

                let saveThrewError = false
                try {
                    await postRepository.save(post)
                    // eslint-disable-next-line @typescript-eslint/no-unused-vars
                } catch (error) {
                    saveThrewError = true
                }

                // The vector(3) DDL constraint must reject the 2-element value
                // and leave the row unpersisted (no generated id).
                expect(saveThrewError).to.be.true
                expect(post.id).to.be.undefined

                const foundPostWithMalformedEmbedding = await connection
                    .getRepository(Post)
                    .createQueryBuilder("p")
                    .where(
                        "p.embedding_three_dimensions::text = :embeddingText",
                        {
                            embeddingText: "[1,1]",
                        },
                    )
                    .getOne()
                expect(foundPostWithMalformedEmbedding).to.be.null
            }),
        ))

    it("should perform halfvec similarity search using L2 distance", () =>
        Promise.all(
            connections.map(async (connection) => {
                const postRepository = connection.getRepository(Post)
                await postRepository.clear()

                // Create test posts with known halfvec values
                await postRepository.save([
                    { halfvec_four_dimensions: [1, 1, 1, 1] },
                    { halfvec_four_dimensions: [1, 1, 2, 2] },
                    { halfvec_four_dimensions: [5, 5, 5, 5] },
                    { halfvec_four_dimensions: [2, 2, 2, 2] },
                ])

                const queryVector = "[1,1,1.8,1.8]" // Search vector

                const results = await connection.query(
                    `SELECT id, halfvec_four_dimensions FROM "post" ORDER BY halfvec_four_dimensions <-> $1 LIMIT 2`,
                    [queryVector],
                )

                expect(results.length).to.equal(2)
                // [1,1,2,2] should be closest to [1,1,1.8,1.8], then [1,1,1,1]
                expect(results[0].halfvec_four_dimensions).to.deep.equal(
                    "[1,1,2,2]",
                )
                expect(results[1].halfvec_four_dimensions).to.deep.equal(
                    "[1,1,1,1]",
                )
            }),
        ))

    it("should prevent persistence of Post with incorrect halfvec dimensions due to DB constraints", () =>
        Promise.all(
            connections.map(async (connection) => {
                const postRepository = connection.getRepository(Post)
                const post = new Post()
                post.halfvec_four_dimensions = [1, 1, 1] // Wrong dimensions (3 instead of 4)

                let saveThrewError = false
                try {
                    await postRepository.save(post)
                    // eslint-disable-next-line @typescript-eslint/no-unused-vars
                } catch (error) {
                    saveThrewError = true
                }

                // The halfvec(4) DDL constraint must reject the 3-element value
                // and leave the row unpersisted (no generated id).
                expect(saveThrewError).to.be.true
                expect(post.id).to.be.undefined

                const foundPostWithMalformedHalfvec = await connection
                    .getRepository(Post)
                    .createQueryBuilder("p")
                    .where("p.halfvec_four_dimensions::text = :embeddingText", {
                        embeddingText: "[1,1,1]",
                    })
                    .getOne()
                expect(foundPostWithMalformedHalfvec).to.be.null
            }),
        ))
})
|
||||
110
test/functional/database-schema/vectors/postgres/vector.ts
Normal file
110
test/functional/database-schema/vectors/postgres/vector.ts
Normal file
@ -0,0 +1,110 @@
|
||||
import "reflect-metadata"
|
||||
import { expect } from "chai"
|
||||
import { DataSource } from "../../../../../src/data-source/DataSource"
|
||||
import {
|
||||
closeTestingConnections,
|
||||
createTestingConnections,
|
||||
reloadTestingDatabases,
|
||||
} from "../../../../utils/test-utils"
|
||||
import { Post } from "./entity/Post"
|
||||
|
||||
// Verifies that vector/halfvec columns are created with the expected type and
// length metadata, and that values round-trip through save/findOne/update.
describe("columns > vector type", () => {
    let connections: DataSource[]
    before(async () => {
        connections = await createTestingConnections({
            entities: [Post],
            enabledDrivers: ["postgres"],
            schemaCreate: true,
            dropSchema: true,
        })
    })

    beforeEach(() => reloadTestingDatabases(connections))
    after(() => closeTestingConnections(connections))

    it("should create vector column", () =>
        Promise.all(
            connections.map(async (connection) => {
                const postRepository = connection.getRepository(Post)
                // Inspect the synchronized schema; release the runner before
                // exercising the repository so the connection is not held.
                const queryRunner = connection.createQueryRunner()
                const table = await queryRunner.getTable("post")
                await queryRunner.release()

                const embedding = [1.0, 2.0, 3.0]
                const embedding_three_dimensions = [1.0, 2.0, 3.0]
                const halfvec_embedding = [1.5, 2.5]
                const halfvec_four_dimensions = [1.5, 2.5, 3.5, 4.5]

                const post = new Post()
                post.embedding = embedding
                post.embedding_three_dimensions = embedding_three_dimensions
                post.halfvec_embedding = halfvec_embedding
                post.halfvec_four_dimensions = halfvec_four_dimensions

                await postRepository.save(post)

                const loadedPost = await postRepository.findOne({
                    where: { id: post.id },
                })

                // Values must round-trip as number[] (driver parses "[...]" text).
                expect(loadedPost).to.exist
                expect(loadedPost!.embedding).to.deep.equal(embedding)
                expect(loadedPost!.embedding_three_dimensions).to.deep.equal(
                    embedding_three_dimensions,
                )
                expect(loadedPost!.halfvec_embedding).to.deep.equal(
                    halfvec_embedding,
                )
                expect(loadedPost!.halfvec_four_dimensions).to.deep.equal(
                    halfvec_four_dimensions,
                )

                // Schema metadata: types and declared lengths come back from
                // the query runner's column introspection.
                table!
                    .findColumnByName("embedding")!
                    .type.should.be.equal("vector")
                table!
                    .findColumnByName("embedding_three_dimensions")!
                    .type.should.be.equal("vector")
                table!
                    .findColumnByName("embedding_three_dimensions")!
                    .length!.should.be.equal("3")
                table!
                    .findColumnByName("halfvec_embedding")!
                    .type.should.be.equal("halfvec")
                table!
                    .findColumnByName("halfvec_four_dimensions")!
                    .type.should.be.equal("halfvec")
                table!
                    .findColumnByName("halfvec_four_dimensions")!
                    .length!.should.be.equal("4")
            }),
        ))

    it("should update vector values", () =>
        Promise.all(
            connections.map(async (connection) => {
                const postRepository = connection.getRepository(Post)

                const post = new Post()
                post.embedding = [1.0, 2.0]
                post.embedding_three_dimensions = [3.0, 4.0, 5.0]

                await postRepository.save(post)

                // Overwrite and save again to exercise the UPDATE path.
                post.embedding = [5.0, 6.0]
                post.embedding_three_dimensions = [7.0, 8.0, 9.0]

                await postRepository.save(post)

                const loadedPost = await postRepository.findOne({
                    where: { id: post.id },
                })

                expect(loadedPost).to.exist
                expect(loadedPost!.embedding).to.deep.equal([5.0, 6.0])
                expect(loadedPost!.embedding_three_dimensions).to.deep.equal([
                    7.0, 8.0, 9.0,
                ])
            }),
        ))
})
|
||||
@ -0,0 +1,20 @@
|
||||
import { MigrationInterface } from "../../../../src/migration/MigrationInterface"
|
||||
import { QueryRunner } from "../../../../src/query-runner/QueryRunner"
|
||||
|
||||
/**
 * Migration creating the "post" table with pgvector columns, used to verify
 * that vector/halfvec DDL (with and without dimensions) runs through the
 * migration runner. Requires the pgvector extension to be installed.
 */
export class CreatePost0000000000001 implements MigrationInterface {
    /** Creates the table with dimensioned and undimensioned vector/halfvec columns. */
    public async up(queryRunner: QueryRunner): Promise<any> {
        await queryRunner.query(`
            CREATE TABLE "post" (
                "id" SERIAL PRIMARY KEY,
                "embedding" vector,
                "embedding_three_dimensions" vector(3),
                "halfvec_embedding" halfvec,
                "halfvec_four_dimensions" halfvec(4)
            )
        `)
    }

    /** Reverts up() by dropping the table. */
    public async down(queryRunner: QueryRunner): Promise<any> {
        await queryRunner.query(`DROP TABLE "post"`)
    }
}
|
||||
75
test/functional/migrations/vector/vector-test.ts
Normal file
75
test/functional/migrations/vector/vector-test.ts
Normal file
@ -0,0 +1,75 @@
|
||||
import "reflect-metadata"
|
||||
import { DataSource } from "../../../../src/data-source/DataSource"
|
||||
import {
|
||||
closeTestingConnections,
|
||||
createTestingConnections,
|
||||
reloadTestingDatabases,
|
||||
} from "../../../utils/test-utils"
|
||||
import { CreatePost0000000000001 } from "./0000000000001-CreatePost"
|
||||
|
||||
// Verifies that the vector migration creates the expected schema and that raw
// inserts/selects of vector data work after running migrations.
describe("migrations > vector type", () => {
    let connections: DataSource[]

    before(async () => {
        connections = await createTestingConnections({
            enabledDrivers: ["postgres"],
            // Schema comes from the migration, not from entity sync.
            schemaCreate: false,
            dropSchema: true,
            migrations: [CreatePost0000000000001],
        })
    })

    beforeEach(() => reloadTestingDatabases(connections))
    after(() => closeTestingConnections(connections))

    it("should run vector migration and create table with vector columns", () =>
        Promise.all(
            connections.map(async (connection) => {
                await connection.runMigrations()

                const queryRunner = connection.createQueryRunner()
                const table = await queryRunner.getTable("post")
                await queryRunner.release()

                // Introspected column types/lengths must match the DDL in the
                // migration (vector, vector(3), halfvec, halfvec(4)).
                table!
                    .findColumnByName("embedding")!
                    .type.should.be.equal("vector")
                table!
                    .findColumnByName("embedding_three_dimensions")!
                    .type.should.be.equal("vector")
                table!
                    .findColumnByName("embedding_three_dimensions")!
                    .length!.should.be.equal("3")
                table!
                    .findColumnByName("halfvec_embedding")!
                    .type.should.be.equal("halfvec")
                table!
                    .findColumnByName("halfvec_four_dimensions")!
                    .type.should.be.equal("halfvec")
                table!
                    .findColumnByName("halfvec_four_dimensions")!
                    .length!.should.be.equal("4")
            }),
        ))

    it("should handle vector data after migration", () =>
        Promise.all(
            connections.map(async (connection) => {
                await connection.runMigrations()

                const queryRunner = connection.createQueryRunner()
                await queryRunner.query(
                    'INSERT INTO "post"("embedding", "embedding_three_dimensions", "halfvec_embedding", "halfvec_four_dimensions") VALUES (\'[1,2,3,4]\', \'[4,5,6]\', \'[1.5,2.5]\', \'[1,2,3,4]\')',
                )

                const result = await queryRunner.query('SELECT * FROM "post"')
                await queryRunner.release()

                // Raw queries return vectors in pgvector's "[...]" text form.
                result.length.should.be.equal(1)
                result[0].embedding.should.equal("[1,2,3,4]")
                result[0].embedding_three_dimensions.should.equal("[4,5,6]")
                result[0].halfvec_embedding.should.equal("[1.5,2.5]")
                result[0].halfvec_four_dimensions.should.equal("[1,2,3,4]")
            }),
        ))
})
|
||||
Loading…
x
Reference in New Issue
Block a user