Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions apps/gateway/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,19 @@ async fn main() -> Result<()> {
// Cloud: KMS envelope decryption (calls KMS Decrypt for each data key)
let crypto = Arc::new(crypto::CryptoService::from_env().await?);

let policy_engine = Arc::new(PolicyEngine { pool, crypto });
let policy_engine = Arc::new(PolicyEngine {
pool,
crypto: Arc::clone(&crypto),
});

// Initialize vault service with Bitwarden provider
let proxy_url = std::env::var("BITWARDEN_PROXY_URL")
.unwrap_or_else(|_| "wss://ap.lesspassword.dev".to_string());
let bitwarden =
BitwardenVaultProvider::new(BitwardenConfig { proxy_url }, policy_engine.pool.clone());
let bitwarden = BitwardenVaultProvider::new(
BitwardenConfig { proxy_url },
policy_engine.pool.clone(),
Arc::clone(&crypto),
);
let vault_service = Arc::new(VaultService::new(
vec![Box::new(bitwarden)],
policy_engine.pool.clone(),
Expand Down
28 changes: 21 additions & 7 deletions apps/gateway/src/vault/bitwarden.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,12 @@ use sqlx::PgPool;
use tokio::sync::{mpsc, Mutex};
use tracing::{info, warn};

use super::bitwarden_db::{BitwardenConnectionStore, BitwardenIdentityProvider};
use super::bitwarden_db::{
decrypt_connection_data, encrypt_connection_data, BitwardenConnectionStore,
BitwardenIdentityProvider,
};
use super::{PairResult, ProviderStatus, VaultCredential, VaultProvider};
use crate::crypto::CryptoService;
use crate::db;

/// Parse a hex-encoded fingerprint string into an `IdentityFingerprint`.
Expand Down Expand Up @@ -140,16 +144,18 @@ pub(crate) struct BitwardenConfig {
pub(crate) struct BitwardenVaultProvider {
config: BitwardenConfig,
pool: PgPool,
crypto: Arc<CryptoService>,
sessions: Arc<DashMap<String, Arc<BitwardenUserSession>>>,
}

impl BitwardenVaultProvider {
pub fn new(config: BitwardenConfig, pool: PgPool) -> Self {
pub fn new(config: BitwardenConfig, pool: PgPool, crypto: Arc<CryptoService>) -> Self {
let sessions = Arc::new(DashMap::new());
Self::spawn_eviction_task(Arc::clone(&sessions));
Self {
config,
pool,
crypto,
sessions,
}
}
Expand Down Expand Up @@ -204,10 +210,16 @@ impl BitwardenVaultProvider {
None => return Ok(None),
};

let cd: Option<BitwardenConnectionData> = row
.connection_data
.as_ref()
.and_then(|v| serde_json::from_value(v.clone()).ok());
let cd: Option<BitwardenConnectionData> = match row.connection_data.as_ref() {
Some(v) => match decrypt_connection_data(&self.crypto, v).await {
Ok(cd) => Some(cd),
Err(e) => {
warn!(error = %e, account_id, "failed to decrypt vault connection data");
None
}
},
None => None,
};

let key_data = cd.as_ref().and_then(|c| c.key_data.as_ref());
let identity = match key_data {
Expand Down Expand Up @@ -265,6 +277,7 @@ impl BitwardenVaultProvider {
self.pool.clone(),
account_id.to_string(),
key_data,
Arc::clone(&self.crypto),
session.connection_data.as_ref(),
);

Expand Down Expand Up @@ -390,12 +403,13 @@ impl VaultProvider for BitwardenVaultProvider {
key_data: Some(session.identity.to_cose()),
transport_state: None,
};
let encrypted_cd = encrypt_connection_data(&self.crypto, &initial_cd).await?;
db::upsert_vault_connection(
&self.pool,
account_id,
"bitwarden",
"paired",
Some(&serde_json::to_value(&initial_cd)?),
Some(&encrypted_cd),
)
.await?;

Expand Down
143 changes: 134 additions & 9 deletions apps/gateway/src/vault/bitwarden_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
//! Instead of files, identity keypair and connection/transport state are stored in
//! the `VaultConnection.connectionData` JSON column (scoped to `provider = "bitwarden"`).

use std::sync::Arc;

use anyhow::Context;
use ap_client::{
ClientError, ConnectionInfo, ConnectionStore, ConnectionUpdate, IdentityFingerprint,
IdentityProvider,
Expand All @@ -14,6 +17,7 @@ use sqlx::PgPool;
use tracing::warn;

use super::bitwarden::{parse_fingerprint, BitwardenConnectionData};
use crate::crypto::CryptoService;
use crate::db;

// ── BitwardenIdentityProvider ───────────────────────────────────────────
Expand Down Expand Up @@ -69,11 +73,14 @@ impl IdentityProvider for BitwardenIdentityProvider {
/// DB-backed connection store, scoped to a single `account_id` with `provider = "bitwarden"`.
///
/// Connections are cached in memory. Writes go through to the DB directly via async calls.
/// Sensitive fields (`key_data`, `transport_state`) are encrypted with AES-256-GCM before
/// being written to the `connection_data` JSON column.
pub(crate) struct BitwardenConnectionStore {
pool: PgPool,
account_id: String,
/// COSE-encoded keypair — kept here so write-throughs don't null out key_data in DB.
key_data: Option<Vec<u8>>,
crypto: Arc<CryptoService>,
/// In-memory connection (at most one per user for Bitwarden).
connection: Option<ConnectionInfo>,
}
Expand All @@ -84,6 +91,7 @@ impl BitwardenConnectionStore {
pool: PgPool,
account_id: String,
key_data: Option<Vec<u8>>,
crypto: Arc<CryptoService>,
connection_data: Option<&BitwardenConnectionData>,
) -> Self {
let connection = connection_data.and_then(|cd| {
Expand All @@ -107,16 +115,17 @@ impl BitwardenConnectionStore {
pool,
account_id,
key_data,
crypto,
connection,
}
}

/// Persist the current connection data to DB.
/// Persist the current connection data to DB (encrypted).
async fn write_through(&self, cd: &BitwardenConnectionData) {
let json = match serde_json::to_value(cd) {
let json = match encrypt_connection_data(&self.crypto, cd).await {
Ok(v) => v,
Err(e) => {
warn!(error = %e, "failed to serialize BitwardenConnectionData");
warn!(error = %e, "failed to encrypt connection data for write-through");
return;
}
};
Expand Down Expand Up @@ -179,10 +188,61 @@ fn now_timestamp() -> u64 {
.as_secs()
}

// ── Connection data encryption ────────────────────────────────────────

/// Encrypt a `BitwardenConnectionData` for DB storage.
/// Returns a JSON value `{"encrypted": "iv:authTag:ciphertext"}`.
pub(super) async fn encrypt_connection_data(
crypto: &CryptoService,
cd: &BitwardenConnectionData,
) -> anyhow::Result<serde_json::Value> {
let json_str = serde_json::to_string(cd).context("serializing connection data")?;
let encrypted = crypto
.encrypt(&json_str)
.await
.context("encrypting connection data")?;
Ok(serde_json::json!({ "encrypted": encrypted }))
}

/// Decrypt a `connection_data` JSON value from the DB.
/// Supports both encrypted (`{"encrypted": "..."}`) and legacy plaintext formats.
/// Legacy rows are transparently upgraded to encrypted on next write-through.
pub(super) async fn decrypt_connection_data(
crypto: &CryptoService,
value: &serde_json::Value,
) -> anyhow::Result<BitwardenConnectionData> {
if let Some(encrypted_str) = value.get("encrypted").and_then(|v| v.as_str()) {
let json_str = crypto
.decrypt(encrypted_str)
.await
.context("decrypting connection data")?;
serde_json::from_str(&json_str).context("deserializing decrypted connection data")
} else {
// Legacy: unencrypted connection data — will be encrypted on next write-through
serde_json::from_value(value.clone()).context("deserializing legacy connection data")
}
}

#[cfg(test)]
mod tests {
use super::*;

// ── Test helpers ──────────────────────────────────────────────────

fn test_crypto() -> Arc<CryptoService> {
use base64::Engine;
use ring::rand::{SecureRandom, SystemRandom};
let rng = SystemRandom::new();
let mut key = [0u8; 32];
rng.fill(&mut key).expect("generate random key");
let key_b64 = base64::engine::general_purpose::STANDARD.encode(key);
Arc::new(CryptoService::from_base64_key(&key_b64).expect("create test crypto"))
}

fn fake_pool() -> PgPool {
sqlx::PgPool::connect_lazy("postgres://fake").expect("lazy pool")
}

// ── BitwardenIdentityProvider ──────────────────────────────────────

#[test]
Expand All @@ -207,13 +267,10 @@ mod tests {

// ── BitwardenConnectionStore construction ─────────────────────────

fn fake_pool() -> PgPool {
sqlx::PgPool::connect_lazy("postgres://fake").expect("lazy pool")
}

#[tokio::test]
async fn connection_store_new_without_data() {
let store = BitwardenConnectionStore::new(fake_pool(), "user1".into(), None, None);
let store =
BitwardenConnectionStore::new(fake_pool(), "user1".into(), None, test_crypto(), None);
assert!(store.connection.is_none());
}

Expand All @@ -229,6 +286,7 @@ mod tests {
fake_pool(),
"user1".into(),
Some(vec![1, 2, 3]),
test_crypto(),
Some(&cd),
);

Expand All @@ -243,7 +301,13 @@ mod tests {
key_data: Some(vec![1]),
transport_state: None,
};
let store = BitwardenConnectionStore::new(fake_pool(), "user1".into(), None, Some(&cd));
let store = BitwardenConnectionStore::new(
fake_pool(),
"user1".into(),
None,
test_crypto(),
Some(&cd),
);
assert!(store.connection.is_none());
}

Expand All @@ -254,6 +318,7 @@ mod tests {
fake_pool(),
"user1".into(),
Some(key_data.clone()),
test_crypto(),
None,
);

Expand All @@ -267,4 +332,64 @@ mod tests {
let cd = store.to_connection_data(&info);
assert_eq!(cd.key_data, Some(key_data));
}

// ── Connection data encryption ────────────────────────────────────

#[tokio::test]
async fn encrypt_decrypt_connection_data_round_trip() {
let crypto = test_crypto();
let cd = BitwardenConnectionData {
fingerprint: Some(hex::encode([42u8; 32])),
key_data: Some(vec![1, 2, 3, 4, 5]),
transport_state: Some(vec![10, 20, 30]),
};

let encrypted = encrypt_connection_data(&crypto, &cd)
.await
.expect("encrypt");
assert!(
encrypted.get("encrypted").is_some(),
"should have encrypted wrapper"
);

let decrypted = decrypt_connection_data(&crypto, &encrypted)
.await
.expect("decrypt");
assert_eq!(decrypted.fingerprint, cd.fingerprint);
assert_eq!(decrypted.key_data, cd.key_data);
assert_eq!(decrypted.transport_state, cd.transport_state);
}

#[tokio::test]
async fn decrypt_legacy_plaintext_connection_data() {
let crypto = test_crypto();
let cd = BitwardenConnectionData {
fingerprint: Some("abc123".into()),
key_data: Some(vec![1, 2, 3]),
transport_state: None,
};
let plaintext_json = serde_json::to_value(&cd).expect("serialize");

let decrypted = decrypt_connection_data(&crypto, &plaintext_json)
.await
.expect("should handle legacy plaintext");
assert_eq!(decrypted.fingerprint, cd.fingerprint);
assert_eq!(decrypted.key_data, cd.key_data);
}

#[tokio::test]
async fn decrypt_with_wrong_key_fails() {
let crypto1 = test_crypto();
let crypto2 = test_crypto();
let cd = BitwardenConnectionData {
fingerprint: Some("test".into()),
key_data: Some(vec![1]),
transport_state: None,
};

let encrypted = encrypt_connection_data(&crypto1, &cd)
.await
.expect("encrypt");
assert!(decrypt_connection_data(&crypto2, &encrypted).await.is_err());
}
}
Loading