Skip to content

WIP 🔑feat: Implement End-to-End Encryption E2EE Across Messaging #5906

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
107 changes: 97 additions & 10 deletions api/app/clients/BaseClient.js
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
const crypto = require('crypto');
const fetch = require('node-fetch');
const {
supportsBalanceCheck,
Expand All @@ -9,14 +8,56 @@
ErrorTypes,
Constants,
} = require('librechat-data-provider');
const { getMessages, saveMessage, updateMessage, saveConvo, getConvo } = require('~/models');
const { getMessages, saveMessage, updateMessage, saveConvo, getConvo, getUserById } = require('~/models');
const { addSpaceIfNeeded, isEnabled } = require('~/server/utils');
const { truncateToolCallOutputs } = require('./prompts');
const checkBalance = require('~/models/checkBalance');
const { getFiles } = require('~/models/File');
const TextStream = require('./TextStream');
const { logger } = require('~/config');

let crypto;
try {
crypto = require('crypto');
} catch (err) {
logger.error('[AskController] crypto support is disabled!', err);
}

/**
* Helper function to encrypt plaintext using AES-256-GCM and then RSA-encrypt the AES key.
* @param {string} plainText - The plaintext to encrypt.
* @param {string} pemPublicKey - The RSA public key in PEM format.
* @returns {Object} An object containing the ciphertext, iv, authTag, and encryptedKey.
*/
function encryptText(plainText, pemPublicKey) {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please move this encryption function outside of this module to where the other encryption methods live. also it seems duplicated elsewhere

// Generate a random 256-bit AES key and a 12-byte IV.
const aesKey = crypto.randomBytes(32);
const iv = crypto.randomBytes(12);

// Encrypt the plaintext using AES-256-GCM.
const cipher = crypto.createCipheriv('aes-256-gcm', aesKey, iv);
let ciphertext = cipher.update(plainText, 'utf8', 'base64');
ciphertext += cipher.final('base64');
const authTag = cipher.getAuthTag().toString('base64');

// Encrypt the AES key using the user's RSA public key.
const encryptedKey = crypto.publicEncrypt(
{
key: pemPublicKey,
padding: crypto.constants.RSA_PKCS1_OAEP_PADDING,
oaepHash: 'sha256',
},
aesKey,
).toString('base64');

return {
ciphertext,
iv: iv.toString('base64'),
authTag,
encryptedKey,
};
}

class BaseClient {
constructor(apiKey, options = {}) {
this.apiKey = apiKey;
Expand Down Expand Up @@ -849,18 +890,64 @@
* @param {string | null} user
*/
async saveMessageToDatabase(message, endpointOptions, user = null) {
if (this.user && user !== this.user) {
// Normalize the user information:
// If "user" is an object, use it; otherwise, if a string is passed use req.user (if available)
const currentUser =
user && typeof user === 'object'
? user
: (this.options.req && this.options.req.user
? this.options.req.user
: { id: user });
Comment on lines +896 to +900

Check warning

Code scanning / ESLint

Disallow nested ternary expressions Warning

Do not nest ternary expressions.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no nested ternaries please

const currentUserId = currentUser.id || currentUser;

// Check if the client’s stored user matches the current user.
// (this.user might have been set earlier in setMessageOptions)
const storedUserId =
this.user && typeof this.user === 'object' ? this.user.id : this.user;
if (storedUserId && currentUserId && storedUserId !== currentUserId) {
throw new Error('User mismatch.');
}

// console.log('User ID:', currentUserId);

const dbUser = await getUserById(currentUserId, 'encryptionPublicKey');

// --- NEW ENCRYPTION BLOCK: Encrypt AI response if encryptionPublicKey exists ---
if (dbUser.encryptionPublicKey && message && message.text) {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if possible, consolidate the encryption logic to its own function and outside of this module

try {
// Rebuild the PEM format if necessary.
const pemPublicKey = `-----BEGIN PUBLIC KEY-----\n${dbUser.encryptionPublicKey
.match(/.{1,64}/g)
.join('\n')}\n-----END PUBLIC KEY-----`;
const { ciphertext, iv, authTag, encryptedKey } = encryptText(
message.text,
pemPublicKey,
);
message.text = ciphertext;
message.iv = iv;
message.authTag = authTag;
message.encryptedKey = encryptedKey;
logger.debug('[BaseClient.saveMessageToDatabase] Encrypted message text');
} catch (err) {
logger.error('[BaseClient.saveMessageToDatabase] Error encrypting message text', err);
}
}
// --- End Encryption Block ---

// Build update parameters including encryption fields.
const updateParams = {
...message,
endpoint: this.options.endpoint,
unfinished: false,
user: currentUserId, // store the user id (ensured to be a string)
iv: message.iv ?? null,
authTag: message.authTag ?? null,
encryptedKey: message.encryptedKey ?? null,
};

const savedMessage = await saveMessage(
this.options.req,
{
...message,
endpoint: this.options.endpoint,
unfinished: false,
user,
},
updateParams,
{ context: 'api/app/clients/BaseClient.js - saveMessageToDatabase #saveMessage' },
);

Expand Down Expand Up @@ -1149,4 +1236,4 @@
}
}

module.exports = BaseClient;
module.exports = BaseClient;
49 changes: 28 additions & 21 deletions api/models/Message.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ const { z } = require('zod');
const Message = require('./schema/messageSchema');
const { logger } = require('~/config');

// Validate conversation ID as a UUID (if your conversation IDs follow UUID format)
const idSchema = z.string().uuid();

/**
Expand All @@ -28,8 +29,11 @@ const idSchema = z.string().uuid();
* @param {string} [params.plugin] - Plugin associated with the message.
* @param {string[]} [params.plugins] - An array of plugins associated with the message.
* @param {string} [params.model] - The model used to generate the message.
* @param {Object} [metadata] - Additional metadata for this operation
* @param {string} [metadata.context] - The context of the operation
* @param {string} [params.iv] - (Optional) Base64-encoded initialization vector for encryption.
* @param {string} [params.authTag] - (Optional) Base64-encoded authentication tag from AES-GCM.
* @param {string} [params.encryptedKey] - (Optional) Base64-encoded AES key encrypted with RSA.
* @param {Object} [metadata] - Additional metadata for this operation.
* @param {string} [metadata.context] - The context of the operation.
* @returns {Promise<TMessage>} The updated or newly inserted message document.
* @throws {Error} If there is an error in saving the message.
*/
Expand All @@ -51,6 +55,9 @@ async function saveMessage(req, params, metadata) {
...params,
user: req.user.id,
messageId: params.newMessageId || params.messageId,
iv: params.iv ?? null,
authTag: params.authTag ?? null,
encryptedKey: params.encryptedKey ?? null,
};

if (req?.body?.isTemporary) {
Expand Down Expand Up @@ -90,7 +97,12 @@ async function bulkSaveMessages(messages, overrideTimestamp = false) {
const bulkOps = messages.map((message) => ({
updateOne: {
filter: { messageId: message.messageId },
update: message,
update: {
...message,
iv: message.iv ?? null,
authTag: message.authTag ?? null,
encryptedKey: message.encryptedKey ?? null,
},
timestamps: !overrideTimestamp,
upsert: true,
},
Expand Down Expand Up @@ -119,14 +131,7 @@ async function bulkSaveMessages(messages, overrideTimestamp = false) {
* @returns {Promise<Object>} The updated or newly inserted message document.
* @throws {Error} If there is an error in saving the message.
*/
async function recordMessage({
user,
endpoint,
messageId,
conversationId,
parentMessageId,
...rest
}) {
async function recordMessage({ user, endpoint, messageId, conversationId, parentMessageId, ...rest }) {
try {
// No parsing of convoId as may use threadId
const message = {
Expand All @@ -136,6 +141,9 @@ async function recordMessage({
conversationId,
parentMessageId,
...rest,
iv: rest.iv ?? null,
authTag: rest.authTag ?? null,
encryptedKey: rest.encryptedKey ?? null,
};

return await Message.findOneAndUpdate({ user, messageId }, message, {
Expand Down Expand Up @@ -190,12 +198,15 @@ async function updateMessageText(req, { messageId, text }) {
async function updateMessage(req, message, metadata) {
try {
const { messageId, ...update } = message;
// Ensure encryption fields are explicitly updated (if provided)
update.iv = update.iv ?? null;
update.authTag = update.authTag ?? null;
update.encryptedKey = update.encryptedKey ?? null;

const updatedMessage = await Message.findOneAndUpdate(
{ messageId, user: req.user.id },
update,
{
new: true,
},
{ new: true },
);

if (!updatedMessage) {
Expand Down Expand Up @@ -225,11 +236,11 @@ async function updateMessage(req, message, metadata) {
*
* @async
* @function deleteMessagesSince
* @param {Object} params - The parameters object.
* @param {Object} req - The request object.
* @param {Object} params - The parameters object.
* @param {string} params.messageId - The unique identifier for the message.
* @param {string} params.conversationId - The identifier of the conversation.
* @returns {Promise<Number>} The number of deleted messages.
* @returns {Promise<number>} The number of deleted messages.
* @throws {Error} If there is an error in deleting messages.
*/
async function deleteMessagesSince(req, { messageId, conversationId }) {
Expand Down Expand Up @@ -263,7 +274,6 @@ async function getMessages(filter, select) {
if (select) {
return await Message.find(filter).select(select).sort({ createdAt: 1 }).lean();
}

return await Message.find(filter).sort({ createdAt: 1 }).lean();
} catch (err) {
logger.error('Error getting messages:', err);
Expand All @@ -281,10 +291,7 @@ async function getMessages(filter, select) {
*/
async function getMessage({ user, messageId }) {
try {
return await Message.findOne({
user,
messageId,
}).lean();
return await Message.findOne({ user, messageId }).lean();
} catch (err) {
logger.error('Error getting message:', err);
throw err;
Expand Down
12 changes: 12 additions & 0 deletions api/models/schema/messageSchema.js
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,18 @@ const messageSchema = mongoose.Schema(
expiredAt: {
type: Date,
},
iv: {
type: String,
default: null,
},
authTag: {
type: String,
default: null,
},
encryptedKey: {
type: String,
default: null,
},
},
{ timestamps: true },
);
Expand Down
20 changes: 20 additions & 0 deletions api/models/schema/userSchema.js
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ const { SystemRoles } = require('librechat-data-provider');
* @property {Array} [plugins=[]] - List of plugins used by the user
* @property {Array.<MongoSession>} [refreshToken] - List of sessions with refresh tokens
* @property {Date} [expiresAt] - Optional expiration date of the file
* @property {string} [encryptionPublicKey] - The user's encryption public key
* @property {string} [encryptedPrivateKey] - The user's encrypted private key
* @property {string} [encryptionSalt] - The salt used for key derivation (e.g., PBKDF2)
* @property {string} [encryptionIV] - The IV used for encrypting the private key
* @property {Date} [createdAt] - Date when the user was created (added by timestamps)
* @property {Date} [updatedAt] - Date when the user was last updated (added by timestamps)
*/
Expand Down Expand Up @@ -143,6 +147,22 @@ const userSchema = mongoose.Schema(
type: Boolean,
default: false,
},
encryptionPublicKey: {
type: String,
default: null,
},
encryptedPrivateKey: {
type: String,
default: null,
},
encryptionSalt: {
type: String,
default: null,
},
encryptionIV: {
type: String,
default: null,
},
},

{ timestamps: true },
Expand Down
Loading
Loading