From ecff60dc8aa0bd1ad5ea8f4623af0756a86dc110 Mon Sep 17 00:00:00 2001 From: Charmander <~@charmander.me> Date: Thu, 11 Dec 2025 09:07:16 -0800 Subject: [PATCH] fix: Avoid retaining buffer for latest parse in reader (#3533) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * test: Add failing test for parser reader cleanup * fix: Avoid retaining buffer for latest parse in reader The buffer can be arbitrarily large, and the parser shouldn’t keep it around while waiting on (and potentially also buffering) the next complete packet. --- .../pg-protocol/src/inbound-parser.test.ts | 7 + packages/pg-protocol/src/parser.ts | 432 +++++++++--------- 2 files changed, 235 insertions(+), 204 deletions(-) diff --git a/packages/pg-protocol/src/inbound-parser.test.ts b/packages/pg-protocol/src/inbound-parser.test.ts index 0575993d..285f4bf2 100644 --- a/packages/pg-protocol/src/inbound-parser.test.ts +++ b/packages/pg-protocol/src/inbound-parser.test.ts @@ -4,6 +4,7 @@ import { parse } from '.' import assert from 'assert' import { PassThrough } from 'stream' import { BackendMessage } from './messages' +import { Parser } from './parser' const authOkBuffer = buffers.authenticationOk() const paramStatusBuffer = buffers.parameterStatus('client_encoding', 'UTF8') @@ -565,4 +566,10 @@ describe('PgPacketStream', function () { }) }) }) + + it('cleans up the reader after handling a packet', function () { + const parser = new Parser() + parser.parse(oneFieldBuf, () => {}) + assert.strictEqual((parser as any).reader.buffer.byteLength, 0) + }) }) diff --git a/packages/pg-protocol/src/parser.ts b/packages/pg-protocol/src/parser.ts index f7313f23..998077a0 100644 --- a/packages/pg-protocol/src/parser.ts +++ b/packages/pg-protocol/src/parser.ts @@ -36,6 +36,9 @@ const LEN_LENGTH = 4 const HEADER_LENGTH = CODE_LENGTH + LEN_LENGTH +// A placeholder for a `BackendMessage`’s length value that will be set after construction. +const LATEINIT_LENGTH = -1 + export type Packet = { code: number packet: Buffer @@ -152,238 +155,259 @@ export class Parser { } private handlePacket(offset: number, code: number, length: number, bytes: Buffer): BackendMessage { + const { reader } = this + + // NOTE: This undesirably retains the buffer in `this.reader` if the `parse*Message` calls below throw. However, those should only throw in the case of a protocol error, which normally results in the reader being discarded. + reader.setBuffer(offset, bytes) + + let message: BackendMessage + switch (code) { case MessageCodes.BindComplete: - return bindComplete + message = bindComplete + break case MessageCodes.ParseComplete: - return parseComplete + message = parseComplete + break case MessageCodes.CloseComplete: - return closeComplete + message = closeComplete + break case MessageCodes.NoData: - return noData + message = noData + break case MessageCodes.PortalSuspended: - return portalSuspended + message = portalSuspended + break case MessageCodes.CopyDone: - return copyDone + message = copyDone + break case MessageCodes.ReplicationStart: - return replicationStart + message = replicationStart + break case MessageCodes.EmptyQuery: - return emptyQuery + message = emptyQuery + break case MessageCodes.DataRow: - return this.parseDataRowMessage(offset, length, bytes) + message = parseDataRowMessage(reader) + break case MessageCodes.CommandComplete: - return this.parseCommandCompleteMessage(offset, length, bytes) + message = parseCommandCompleteMessage(reader) + break case MessageCodes.ReadyForQuery: - return this.parseReadyForQueryMessage(offset, length, bytes) + message = parseReadyForQueryMessage(reader) + break case MessageCodes.NotificationResponse: - return this.parseNotificationMessage(offset, length, bytes) + message = parseNotificationMessage(reader) + break case MessageCodes.AuthenticationResponse: - return this.parseAuthenticationResponse(offset, length, bytes) + message = parseAuthenticationResponse(reader, length) + break case MessageCodes.ParameterStatus: - return this.parseParameterStatusMessage(offset, length, bytes) + message = parseParameterStatusMessage(reader) + break case MessageCodes.BackendKeyData: - return this.parseBackendKeyData(offset, length, bytes) + message = parseBackendKeyData(reader) + break case MessageCodes.ErrorMessage: - return this.parseErrorMessage(offset, length, bytes, 'error') + message = parseErrorMessage(reader, 'error') + break case MessageCodes.NoticeMessage: - return this.parseErrorMessage(offset, length, bytes, 'notice') + message = parseErrorMessage(reader, 'notice') + break case MessageCodes.RowDescriptionMessage: - return this.parseRowDescriptionMessage(offset, length, bytes) + message = parseRowDescriptionMessage(reader) + break case MessageCodes.ParameterDescriptionMessage: - return this.parseParameterDescriptionMessage(offset, length, bytes) + message = parseParameterDescriptionMessage(reader) + break case MessageCodes.CopyIn: - return this.parseCopyInMessage(offset, length, bytes) + message = parseCopyInMessage(reader) + break case MessageCodes.CopyOut: - return this.parseCopyOutMessage(offset, length, bytes) + message = parseCopyOutMessage(reader) + break case MessageCodes.CopyData: - return this.parseCopyData(offset, length, bytes) + message = parseCopyData(reader, length) + break default: return new DatabaseError('received invalid response: ' + code.toString(16), length, 'error') } - } - private parseReadyForQueryMessage(offset: number, length: number, bytes: Buffer) { - this.reader.setBuffer(offset, bytes) - const status = this.reader.string(1) - return new ReadyForQueryMessage(length, status) - } + reader.setBuffer(0, emptyBuffer) - private parseCommandCompleteMessage(offset: number, length: number, bytes: Buffer) { - this.reader.setBuffer(offset, bytes) - const text = this.reader.cstring() - return new CommandCompleteMessage(length, text) - } - - private parseCopyData(offset: number, length: number, bytes: Buffer) { - const chunk = bytes.slice(offset, offset + (length - 4)) - return new CopyDataMessage(length, chunk) - } - - private parseCopyInMessage(offset: number, length: number, bytes: Buffer) { - return this.parseCopyMessage(offset, length, bytes, 'copyInResponse') - } - - private parseCopyOutMessage(offset: number, length: number, bytes: Buffer) { - return this.parseCopyMessage(offset, length, bytes, 'copyOutResponse') - } - - private parseCopyMessage(offset: number, length: number, bytes: Buffer, messageName: MessageName) { - this.reader.setBuffer(offset, bytes) - const isBinary = this.reader.byte() !== 0 - const columnCount = this.reader.int16() - const message = new CopyResponse(length, messageName, isBinary, columnCount) - for (let i = 0; i < columnCount; i++) { - message.columnTypes[i] = this.reader.int16() - } - return message - } - - private parseNotificationMessage(offset: number, length: number, bytes: Buffer) { - this.reader.setBuffer(offset, bytes) - const processId = this.reader.int32() - const channel = this.reader.cstring() - const payload = this.reader.cstring() - return new NotificationResponseMessage(length, processId, channel, payload) - } - - private parseRowDescriptionMessage(offset: number, length: number, bytes: Buffer) { - this.reader.setBuffer(offset, bytes) - const fieldCount = this.reader.int16() - const message = new RowDescriptionMessage(length, fieldCount) - for (let i = 0; i < fieldCount; i++) { - message.fields[i] = this.parseField() - } - return message - } - - private parseField(): Field { - const name = this.reader.cstring() - const tableID = this.reader.uint32() - const columnID = this.reader.int16() - const dataTypeID = this.reader.uint32() - const dataTypeSize = this.reader.int16() - const dataTypeModifier = this.reader.int32() - const mode = this.reader.int16() === 0 ? 'text' : 'binary' - return new Field(name, tableID, columnID, dataTypeID, dataTypeSize, dataTypeModifier, mode) - } - - private parseParameterDescriptionMessage(offset: number, length: number, bytes: Buffer) { - this.reader.setBuffer(offset, bytes) - const parameterCount = this.reader.int16() - const message = new ParameterDescriptionMessage(length, parameterCount) - for (let i = 0; i < parameterCount; i++) { - message.dataTypeIDs[i] = this.reader.int32() - } - return message - } - - private parseDataRowMessage(offset: number, length: number, bytes: Buffer) { - this.reader.setBuffer(offset, bytes) - const fieldCount = this.reader.int16() - const fields: any[] = new Array(fieldCount) - for (let i = 0; i < fieldCount; i++) { - const len = this.reader.int32() - // a -1 for length means the value of the field is null - fields[i] = len === -1 ? null : this.reader.string(len) - } - return new DataRowMessage(length, fields) - } - - private parseParameterStatusMessage(offset: number, length: number, bytes: Buffer) { - this.reader.setBuffer(offset, bytes) - const name = this.reader.cstring() - const value = this.reader.cstring() - return new ParameterStatusMessage(length, name, value) - } - - private parseBackendKeyData(offset: number, length: number, bytes: Buffer) { - this.reader.setBuffer(offset, bytes) - const processID = this.reader.int32() - const secretKey = this.reader.int32() - return new BackendKeyDataMessage(length, processID, secretKey) - } - - public parseAuthenticationResponse(offset: number, length: number, bytes: Buffer) { - this.reader.setBuffer(offset, bytes) - const code = this.reader.int32() - // TODO(bmc): maybe better types here - const message: BackendMessage & any = { - name: 'authenticationOk', - length, - } - - switch (code) { - case 0: // AuthenticationOk - break - case 3: // AuthenticationCleartextPassword - if (message.length === 8) { - message.name = 'authenticationCleartextPassword' - } - break - case 5: // AuthenticationMD5Password - if (message.length === 12) { - message.name = 'authenticationMD5Password' - const salt = this.reader.bytes(4) - return new AuthenticationMD5Password(length, salt) - } - break - case 10: // AuthenticationSASL - { - message.name = 'authenticationSASL' - message.mechanisms = [] - let mechanism: string - do { - mechanism = this.reader.cstring() - if (mechanism) { - message.mechanisms.push(mechanism) - } - } while (mechanism) - } - break - case 11: // AuthenticationSASLContinue - message.name = 'authenticationSASLContinue' - message.data = this.reader.string(length - 8) - break - case 12: // AuthenticationSASLFinal - message.name = 'authenticationSASLFinal' - message.data = this.reader.string(length - 8) - break - default: - throw new Error('Unknown authenticationOk message type ' + code) - } - return message - } - - private parseErrorMessage(offset: number, length: number, bytes: Buffer, name: MessageName) { - this.reader.setBuffer(offset, bytes) - const fields: Record = {} - let fieldType = this.reader.string(1) - while (fieldType !== '\0') { - fields[fieldType] = this.reader.cstring() - fieldType = this.reader.string(1) - } - - const messageValue = fields.M - - const message = - name === 'notice' ? new NoticeMessage(length, messageValue) : new DatabaseError(messageValue, length, name) - - message.severity = fields.S - message.code = fields.C - message.detail = fields.D - message.hint = fields.H - message.position = fields.P - message.internalPosition = fields.p - message.internalQuery = fields.q - message.where = fields.W - message.schema = fields.s - message.table = fields.t - message.column = fields.c - message.dataType = fields.d - message.constraint = fields.n - message.file = fields.F - message.line = fields.L - message.routine = fields.R + message.length = length return message } } + +const parseReadyForQueryMessage = (reader: BufferReader) => { + const status = reader.string(1) + return new ReadyForQueryMessage(LATEINIT_LENGTH, status) +} + +const parseCommandCompleteMessage = (reader: BufferReader) => { + const text = reader.cstring() + return new CommandCompleteMessage(LATEINIT_LENGTH, text) +} + +const parseCopyData = (reader: BufferReader, length: number) => { + const chunk = reader.bytes(length - 4) + return new CopyDataMessage(LATEINIT_LENGTH, chunk) +} + +const parseCopyInMessage = (reader: BufferReader) => parseCopyMessage(reader, 'copyInResponse') + +const parseCopyOutMessage = (reader: BufferReader) => parseCopyMessage(reader, 'copyOutResponse') + +const parseCopyMessage = (reader: BufferReader, messageName: MessageName) => { + const isBinary = reader.byte() !== 0 + const columnCount = reader.int16() + const message = new CopyResponse(LATEINIT_LENGTH, messageName, isBinary, columnCount) + for (let i = 0; i < columnCount; i++) { + message.columnTypes[i] = reader.int16() + } + return message +} + +const parseNotificationMessage = (reader: BufferReader) => { + const processId = reader.int32() + const channel = reader.cstring() + const payload = reader.cstring() + return new NotificationResponseMessage(LATEINIT_LENGTH, processId, channel, payload) +} + +const parseRowDescriptionMessage = (reader: BufferReader) => { + const fieldCount = reader.int16() + const message = new RowDescriptionMessage(LATEINIT_LENGTH, fieldCount) + for (let i = 0; i < fieldCount; i++) { + message.fields[i] = parseField(reader) + } + return message +} + +const parseField = (reader: BufferReader) => { + const name = reader.cstring() + const tableID = reader.uint32() + const columnID = reader.int16() + const dataTypeID = reader.uint32() + const dataTypeSize = reader.int16() + const dataTypeModifier = reader.int32() + const mode = reader.int16() === 0 ? 'text' : 'binary' + return new Field(name, tableID, columnID, dataTypeID, dataTypeSize, dataTypeModifier, mode) +} + +const parseParameterDescriptionMessage = (reader: BufferReader) => { + const parameterCount = reader.int16() + const message = new ParameterDescriptionMessage(LATEINIT_LENGTH, parameterCount) + for (let i = 0; i < parameterCount; i++) { + message.dataTypeIDs[i] = reader.int32() + } + return message +} + +const parseDataRowMessage = (reader: BufferReader) => { + const fieldCount = reader.int16() + const fields: any[] = new Array(fieldCount) + for (let i = 0; i < fieldCount; i++) { + const len = reader.int32() + // a -1 for length means the value of the field is null + fields[i] = len === -1 ? null : reader.string(len) + } + return new DataRowMessage(LATEINIT_LENGTH, fields) +} + +const parseParameterStatusMessage = (reader: BufferReader) => { + const name = reader.cstring() + const value = reader.cstring() + return new ParameterStatusMessage(LATEINIT_LENGTH, name, value) +} + +const parseBackendKeyData = (reader: BufferReader) => { + const processID = reader.int32() + const secretKey = reader.int32() + return new BackendKeyDataMessage(LATEINIT_LENGTH, processID, secretKey) +} + +const parseAuthenticationResponse = (reader: BufferReader, length: number) => { + const code = reader.int32() + // TODO(bmc): maybe better types here + const message: BackendMessage & any = { + name: 'authenticationOk', + length, + } + + switch (code) { + case 0: // AuthenticationOk + break + case 3: // AuthenticationCleartextPassword + if (message.length === 8) { + message.name = 'authenticationCleartextPassword' + } + break + case 5: // AuthenticationMD5Password + if (message.length === 12) { + message.name = 'authenticationMD5Password' + const salt = reader.bytes(4) + return new AuthenticationMD5Password(LATEINIT_LENGTH, salt) + } + break + case 10: // AuthenticationSASL + { + message.name = 'authenticationSASL' + message.mechanisms = [] + let mechanism: string + do { + mechanism = reader.cstring() + if (mechanism) { + message.mechanisms.push(mechanism) + } + } while (mechanism) + } + break + case 11: // AuthenticationSASLContinue + message.name = 'authenticationSASLContinue' + message.data = reader.string(length - 8) + break + case 12: // AuthenticationSASLFinal + message.name = 'authenticationSASLFinal' + message.data = reader.string(length - 8) + break + default: + throw new Error('Unknown authenticationOk message type ' + code) + } + return message +} + +const parseErrorMessage = (reader: BufferReader, name: MessageName) => { + const fields: Record = {} + let fieldType = reader.string(1) + while (fieldType !== '\0') { + fields[fieldType] = reader.cstring() + fieldType = reader.string(1) + } + + const messageValue = fields.M + + const message = + name === 'notice' + ? new NoticeMessage(LATEINIT_LENGTH, messageValue) + : new DatabaseError(messageValue, LATEINIT_LENGTH, name) + + message.severity = fields.S + message.code = fields.C + message.detail = fields.D + message.hint = fields.H + message.position = fields.P + message.internalPosition = fields.p + message.internalQuery = fields.q + message.where = fields.W + message.schema = fields.s + message.table = fields.t + message.column = fields.c + message.dataType = fields.d + message.constraint = fields.n + message.file = fields.F + message.line = fields.L + message.routine = fields.R + return message +}