fix: Avoid retaining buffer for latest parse in reader (#3533)

* test: Add failing test for parser reader cleanup

* fix: Avoid retaining buffer for latest parse in reader

The buffer can be arbitrarily large, and the parser shouldn’t keep it around while waiting on (and potentially also buffering) the next complete packet.
This commit is contained in:
Charmander 2025-12-11 09:07:16 -08:00 committed by GitHub
parent 8d493f3b55
commit ecff60dc8a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 235 additions and 204 deletions

View File

@ -4,6 +4,7 @@ import { parse } from '.'
import assert from 'assert'
import { PassThrough } from 'stream'
import { BackendMessage } from './messages'
import { Parser } from './parser'
const authOkBuffer = buffers.authenticationOk()
const paramStatusBuffer = buffers.parameterStatus('client_encoding', 'UTF8')
@ -565,4 +566,10 @@ describe('PgPacketStream', function () {
})
})
})
it('cleans up the reader after handling a packet', function () {
const parser = new Parser()
parser.parse(oneFieldBuf, () => {})
assert.strictEqual((parser as any).reader.buffer.byteLength, 0)
})
})

View File

@ -36,6 +36,9 @@ const LEN_LENGTH = 4
const HEADER_LENGTH = CODE_LENGTH + LEN_LENGTH
// A placeholder for a `BackendMessage`s length value that will be set after construction.
const LATEINIT_LENGTH = -1
export type Packet = {
code: number
packet: Buffer
@ -152,238 +155,259 @@ export class Parser {
}
private handlePacket(offset: number, code: number, length: number, bytes: Buffer): BackendMessage {
const { reader } = this
// NOTE: This undesirably retains the buffer in `this.reader` if the `parse*Message` calls below throw. However, those should only throw in the case of a protocol error, which normally results in the reader being discarded.
reader.setBuffer(offset, bytes)
let message: BackendMessage
switch (code) {
case MessageCodes.BindComplete:
return bindComplete
message = bindComplete
break
case MessageCodes.ParseComplete:
return parseComplete
message = parseComplete
break
case MessageCodes.CloseComplete:
return closeComplete
message = closeComplete
break
case MessageCodes.NoData:
return noData
message = noData
break
case MessageCodes.PortalSuspended:
return portalSuspended
message = portalSuspended
break
case MessageCodes.CopyDone:
return copyDone
message = copyDone
break
case MessageCodes.ReplicationStart:
return replicationStart
message = replicationStart
break
case MessageCodes.EmptyQuery:
return emptyQuery
message = emptyQuery
break
case MessageCodes.DataRow:
return this.parseDataRowMessage(offset, length, bytes)
message = parseDataRowMessage(reader)
break
case MessageCodes.CommandComplete:
return this.parseCommandCompleteMessage(offset, length, bytes)
message = parseCommandCompleteMessage(reader)
break
case MessageCodes.ReadyForQuery:
return this.parseReadyForQueryMessage(offset, length, bytes)
message = parseReadyForQueryMessage(reader)
break
case MessageCodes.NotificationResponse:
return this.parseNotificationMessage(offset, length, bytes)
message = parseNotificationMessage(reader)
break
case MessageCodes.AuthenticationResponse:
return this.parseAuthenticationResponse(offset, length, bytes)
message = parseAuthenticationResponse(reader, length)
break
case MessageCodes.ParameterStatus:
return this.parseParameterStatusMessage(offset, length, bytes)
message = parseParameterStatusMessage(reader)
break
case MessageCodes.BackendKeyData:
return this.parseBackendKeyData(offset, length, bytes)
message = parseBackendKeyData(reader)
break
case MessageCodes.ErrorMessage:
return this.parseErrorMessage(offset, length, bytes, 'error')
message = parseErrorMessage(reader, 'error')
break
case MessageCodes.NoticeMessage:
return this.parseErrorMessage(offset, length, bytes, 'notice')
message = parseErrorMessage(reader, 'notice')
break
case MessageCodes.RowDescriptionMessage:
return this.parseRowDescriptionMessage(offset, length, bytes)
message = parseRowDescriptionMessage(reader)
break
case MessageCodes.ParameterDescriptionMessage:
return this.parseParameterDescriptionMessage(offset, length, bytes)
message = parseParameterDescriptionMessage(reader)
break
case MessageCodes.CopyIn:
return this.parseCopyInMessage(offset, length, bytes)
message = parseCopyInMessage(reader)
break
case MessageCodes.CopyOut:
return this.parseCopyOutMessage(offset, length, bytes)
message = parseCopyOutMessage(reader)
break
case MessageCodes.CopyData:
return this.parseCopyData(offset, length, bytes)
message = parseCopyData(reader, length)
break
default:
return new DatabaseError('received invalid response: ' + code.toString(16), length, 'error')
}
}
private parseReadyForQueryMessage(offset: number, length: number, bytes: Buffer) {
this.reader.setBuffer(offset, bytes)
const status = this.reader.string(1)
return new ReadyForQueryMessage(length, status)
}
reader.setBuffer(0, emptyBuffer)
private parseCommandCompleteMessage(offset: number, length: number, bytes: Buffer) {
this.reader.setBuffer(offset, bytes)
const text = this.reader.cstring()
return new CommandCompleteMessage(length, text)
}
private parseCopyData(offset: number, length: number, bytes: Buffer) {
const chunk = bytes.slice(offset, offset + (length - 4))
return new CopyDataMessage(length, chunk)
}
private parseCopyInMessage(offset: number, length: number, bytes: Buffer) {
return this.parseCopyMessage(offset, length, bytes, 'copyInResponse')
}
private parseCopyOutMessage(offset: number, length: number, bytes: Buffer) {
return this.parseCopyMessage(offset, length, bytes, 'copyOutResponse')
}
private parseCopyMessage(offset: number, length: number, bytes: Buffer, messageName: MessageName) {
this.reader.setBuffer(offset, bytes)
const isBinary = this.reader.byte() !== 0
const columnCount = this.reader.int16()
const message = new CopyResponse(length, messageName, isBinary, columnCount)
for (let i = 0; i < columnCount; i++) {
message.columnTypes[i] = this.reader.int16()
}
return message
}
private parseNotificationMessage(offset: number, length: number, bytes: Buffer) {
this.reader.setBuffer(offset, bytes)
const processId = this.reader.int32()
const channel = this.reader.cstring()
const payload = this.reader.cstring()
return new NotificationResponseMessage(length, processId, channel, payload)
}
private parseRowDescriptionMessage(offset: number, length: number, bytes: Buffer) {
this.reader.setBuffer(offset, bytes)
const fieldCount = this.reader.int16()
const message = new RowDescriptionMessage(length, fieldCount)
for (let i = 0; i < fieldCount; i++) {
message.fields[i] = this.parseField()
}
return message
}
private parseField(): Field {
const name = this.reader.cstring()
const tableID = this.reader.uint32()
const columnID = this.reader.int16()
const dataTypeID = this.reader.uint32()
const dataTypeSize = this.reader.int16()
const dataTypeModifier = this.reader.int32()
const mode = this.reader.int16() === 0 ? 'text' : 'binary'
return new Field(name, tableID, columnID, dataTypeID, dataTypeSize, dataTypeModifier, mode)
}
private parseParameterDescriptionMessage(offset: number, length: number, bytes: Buffer) {
this.reader.setBuffer(offset, bytes)
const parameterCount = this.reader.int16()
const message = new ParameterDescriptionMessage(length, parameterCount)
for (let i = 0; i < parameterCount; i++) {
message.dataTypeIDs[i] = this.reader.int32()
}
return message
}
private parseDataRowMessage(offset: number, length: number, bytes: Buffer) {
this.reader.setBuffer(offset, bytes)
const fieldCount = this.reader.int16()
const fields: any[] = new Array(fieldCount)
for (let i = 0; i < fieldCount; i++) {
const len = this.reader.int32()
// a -1 for length means the value of the field is null
fields[i] = len === -1 ? null : this.reader.string(len)
}
return new DataRowMessage(length, fields)
}
private parseParameterStatusMessage(offset: number, length: number, bytes: Buffer) {
this.reader.setBuffer(offset, bytes)
const name = this.reader.cstring()
const value = this.reader.cstring()
return new ParameterStatusMessage(length, name, value)
}
private parseBackendKeyData(offset: number, length: number, bytes: Buffer) {
this.reader.setBuffer(offset, bytes)
const processID = this.reader.int32()
const secretKey = this.reader.int32()
return new BackendKeyDataMessage(length, processID, secretKey)
}
public parseAuthenticationResponse(offset: number, length: number, bytes: Buffer) {
this.reader.setBuffer(offset, bytes)
const code = this.reader.int32()
// TODO(bmc): maybe better types here
const message: BackendMessage & any = {
name: 'authenticationOk',
length,
}
switch (code) {
case 0: // AuthenticationOk
break
case 3: // AuthenticationCleartextPassword
if (message.length === 8) {
message.name = 'authenticationCleartextPassword'
}
break
case 5: // AuthenticationMD5Password
if (message.length === 12) {
message.name = 'authenticationMD5Password'
const salt = this.reader.bytes(4)
return new AuthenticationMD5Password(length, salt)
}
break
case 10: // AuthenticationSASL
{
message.name = 'authenticationSASL'
message.mechanisms = []
let mechanism: string
do {
mechanism = this.reader.cstring()
if (mechanism) {
message.mechanisms.push(mechanism)
}
} while (mechanism)
}
break
case 11: // AuthenticationSASLContinue
message.name = 'authenticationSASLContinue'
message.data = this.reader.string(length - 8)
break
case 12: // AuthenticationSASLFinal
message.name = 'authenticationSASLFinal'
message.data = this.reader.string(length - 8)
break
default:
throw new Error('Unknown authenticationOk message type ' + code)
}
return message
}
private parseErrorMessage(offset: number, length: number, bytes: Buffer, name: MessageName) {
this.reader.setBuffer(offset, bytes)
const fields: Record<string, string> = {}
let fieldType = this.reader.string(1)
while (fieldType !== '\0') {
fields[fieldType] = this.reader.cstring()
fieldType = this.reader.string(1)
}
const messageValue = fields.M
const message =
name === 'notice' ? new NoticeMessage(length, messageValue) : new DatabaseError(messageValue, length, name)
message.severity = fields.S
message.code = fields.C
message.detail = fields.D
message.hint = fields.H
message.position = fields.P
message.internalPosition = fields.p
message.internalQuery = fields.q
message.where = fields.W
message.schema = fields.s
message.table = fields.t
message.column = fields.c
message.dataType = fields.d
message.constraint = fields.n
message.file = fields.F
message.line = fields.L
message.routine = fields.R
message.length = length
return message
}
}
const parseReadyForQueryMessage = (reader: BufferReader) => {
const status = reader.string(1)
return new ReadyForQueryMessage(LATEINIT_LENGTH, status)
}
const parseCommandCompleteMessage = (reader: BufferReader) => {
const text = reader.cstring()
return new CommandCompleteMessage(LATEINIT_LENGTH, text)
}
const parseCopyData = (reader: BufferReader, length: number) => {
const chunk = reader.bytes(length - 4)
return new CopyDataMessage(LATEINIT_LENGTH, chunk)
}
const parseCopyInMessage = (reader: BufferReader) => parseCopyMessage(reader, 'copyInResponse')
const parseCopyOutMessage = (reader: BufferReader) => parseCopyMessage(reader, 'copyOutResponse')
const parseCopyMessage = (reader: BufferReader, messageName: MessageName) => {
const isBinary = reader.byte() !== 0
const columnCount = reader.int16()
const message = new CopyResponse(LATEINIT_LENGTH, messageName, isBinary, columnCount)
for (let i = 0; i < columnCount; i++) {
message.columnTypes[i] = reader.int16()
}
return message
}
const parseNotificationMessage = (reader: BufferReader) => {
const processId = reader.int32()
const channel = reader.cstring()
const payload = reader.cstring()
return new NotificationResponseMessage(LATEINIT_LENGTH, processId, channel, payload)
}
const parseRowDescriptionMessage = (reader: BufferReader) => {
const fieldCount = reader.int16()
const message = new RowDescriptionMessage(LATEINIT_LENGTH, fieldCount)
for (let i = 0; i < fieldCount; i++) {
message.fields[i] = parseField(reader)
}
return message
}
const parseField = (reader: BufferReader) => {
const name = reader.cstring()
const tableID = reader.uint32()
const columnID = reader.int16()
const dataTypeID = reader.uint32()
const dataTypeSize = reader.int16()
const dataTypeModifier = reader.int32()
const mode = reader.int16() === 0 ? 'text' : 'binary'
return new Field(name, tableID, columnID, dataTypeID, dataTypeSize, dataTypeModifier, mode)
}
const parseParameterDescriptionMessage = (reader: BufferReader) => {
const parameterCount = reader.int16()
const message = new ParameterDescriptionMessage(LATEINIT_LENGTH, parameterCount)
for (let i = 0; i < parameterCount; i++) {
message.dataTypeIDs[i] = reader.int32()
}
return message
}
const parseDataRowMessage = (reader: BufferReader) => {
const fieldCount = reader.int16()
const fields: any[] = new Array(fieldCount)
for (let i = 0; i < fieldCount; i++) {
const len = reader.int32()
// a -1 for length means the value of the field is null
fields[i] = len === -1 ? null : reader.string(len)
}
return new DataRowMessage(LATEINIT_LENGTH, fields)
}
const parseParameterStatusMessage = (reader: BufferReader) => {
const name = reader.cstring()
const value = reader.cstring()
return new ParameterStatusMessage(LATEINIT_LENGTH, name, value)
}
const parseBackendKeyData = (reader: BufferReader) => {
const processID = reader.int32()
const secretKey = reader.int32()
return new BackendKeyDataMessage(LATEINIT_LENGTH, processID, secretKey)
}
const parseAuthenticationResponse = (reader: BufferReader, length: number) => {
const code = reader.int32()
// TODO(bmc): maybe better types here
const message: BackendMessage & any = {
name: 'authenticationOk',
length,
}
switch (code) {
case 0: // AuthenticationOk
break
case 3: // AuthenticationCleartextPassword
if (message.length === 8) {
message.name = 'authenticationCleartextPassword'
}
break
case 5: // AuthenticationMD5Password
if (message.length === 12) {
message.name = 'authenticationMD5Password'
const salt = reader.bytes(4)
return new AuthenticationMD5Password(LATEINIT_LENGTH, salt)
}
break
case 10: // AuthenticationSASL
{
message.name = 'authenticationSASL'
message.mechanisms = []
let mechanism: string
do {
mechanism = reader.cstring()
if (mechanism) {
message.mechanisms.push(mechanism)
}
} while (mechanism)
}
break
case 11: // AuthenticationSASLContinue
message.name = 'authenticationSASLContinue'
message.data = reader.string(length - 8)
break
case 12: // AuthenticationSASLFinal
message.name = 'authenticationSASLFinal'
message.data = reader.string(length - 8)
break
default:
throw new Error('Unknown authenticationOk message type ' + code)
}
return message
}
const parseErrorMessage = (reader: BufferReader, name: MessageName) => {
const fields: Record<string, string> = {}
let fieldType = reader.string(1)
while (fieldType !== '\0') {
fields[fieldType] = reader.cstring()
fieldType = reader.string(1)
}
const messageValue = fields.M
const message =
name === 'notice'
? new NoticeMessage(LATEINIT_LENGTH, messageValue)
: new DatabaseError(messageValue, LATEINIT_LENGTH, name)
message.severity = fields.S
message.code = fields.C
message.detail = fields.D
message.hint = fields.H
message.position = fields.P
message.internalPosition = fields.p
message.internalQuery = fields.q
message.where = fields.W
message.schema = fields.s
message.table = fields.t
message.column = fields.c
message.dataType = fields.d
message.constraint = fields.n
message.file = fields.F
message.line = fields.L
message.routine = fields.R
return message
}