diff --git a/lightning/src/ln/chanmon_update_fail_tests.rs b/lightning/src/ln/chanmon_update_fail_tests.rs
index ce07cb73713..8d9ce533366 100644
--- a/lightning/src/ln/chanmon_update_fail_tests.rs
+++ b/lightning/src/ln/chanmon_update_fail_tests.rs
@@ -3598,3 +3598,117 @@ fn test_glacial_peer_cant_hang() {
 	do_test_glacial_peer_cant_hang(false);
 	do_test_glacial_peer_cant_hang(true);
 }
+
+#[test]
+fn test_partial_claim_mon_update_compl_actions() {
+	// Test that when we claim an MPP payment, the payment preimage is retained in each
+	// `ChannelMonitor` until it has reached every `ChannelMonitor` for a channel which was a
+	// part of the MPP.
+	let chanmon_cfgs = create_chanmon_cfgs(4);
+	let node_cfgs = create_node_cfgs(4, &chanmon_cfgs);
+	let node_chanmgrs = create_node_chanmgrs(4, &node_cfgs, &[None, None, None, None]);
+	let mut nodes = create_network(4, &node_cfgs, &node_chanmgrs);
+
+	let chan_1_scid = create_announced_chan_between_nodes(&nodes, 0, 1).0.contents.short_channel_id;
+	let chan_2_scid = create_announced_chan_between_nodes(&nodes, 0, 2).0.contents.short_channel_id;
+	let (chan_3_update, _, chan_3_id, ..) = create_announced_chan_between_nodes(&nodes, 1, 3);
+	let chan_3_scid = chan_3_update.contents.short_channel_id;
+	let (chan_4_update, _, chan_4_id, ..) = create_announced_chan_between_nodes(&nodes, 2, 3);
+	let chan_4_scid = chan_4_update.contents.short_channel_id;
+
+	let (mut route, payment_hash, preimage, payment_secret) = get_route_and_payment_hash!(&nodes[0], nodes[3], 100000);
+	let path = route.paths[0].clone();
+	route.paths.push(path);
+	route.paths[0].hops[0].pubkey = nodes[1].node.get_our_node_id();
+	route.paths[0].hops[0].short_channel_id = chan_1_scid;
+	route.paths[0].hops[1].short_channel_id = chan_3_scid;
+	route.paths[1].hops[0].pubkey = nodes[2].node.get_our_node_id();
+	route.paths[1].hops[0].short_channel_id = chan_2_scid;
+	route.paths[1].hops[1].short_channel_id = chan_4_scid;
+	send_along_route_with_secret(&nodes[0], route, &[&[&nodes[1], &nodes[3]], &[&nodes[2], &nodes[3]]], 200_000, payment_hash, payment_secret);
+
+	// Claim along both paths, but only complete one of the two monitor updates.
+	chanmon_cfgs[3].persister.set_update_ret(ChannelMonitorUpdateStatus::InProgress);
+	chanmon_cfgs[3].persister.set_update_ret(ChannelMonitorUpdateStatus::InProgress);
+	nodes[3].node.claim_funds(preimage);
+	assert_eq!(nodes[3].node.get_and_clear_pending_msg_events(), Vec::new());
+	assert_eq!(nodes[3].node.get_and_clear_pending_events(), Vec::new());
+	check_added_monitors(&nodes[3], 2);
+
+	// Complete the 1<->3 monitor update and play the commitment_signed dance forward until it
+	// blocks.
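+	// (Node 3 withholds its final revoke_and_ack on the 1<->3 channel: the monitor update it
+	// would generate stays blocked until the preimage has also reached the 2<->3 monitor.)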
+	nodes[3].chain_monitor.complete_sole_pending_chan_update(&chan_3_id);
+	expect_payment_claimed!(&nodes[3], payment_hash, 200_000);
+	let updates = get_htlc_update_msgs(&nodes[3], &nodes[1].node.get_our_node_id());
+
+	nodes[1].node.handle_update_fulfill_htlc(&nodes[3].node.get_our_node_id(), &updates.update_fulfill_htlcs[0]);
+	check_added_monitors(&nodes[1], 1);
+	expect_payment_forwarded!(nodes[1], nodes[0], nodes[3], Some(1000), false, false);
+	let _bs_updates_for_a = get_htlc_update_msgs(&nodes[1], &nodes[0].node.get_our_node_id());
+
+	nodes[1].node.handle_commitment_signed(&nodes[3].node.get_our_node_id(), &updates.commitment_signed);
+	check_added_monitors(&nodes[1], 1);
+	let (bs_raa, bs_cs) = get_revoke_commit_msgs(&nodes[1], &nodes[3].node.get_our_node_id());
+
+	nodes[3].node.handle_revoke_and_ack(&nodes[1].node.get_our_node_id(), &bs_raa);
+	check_added_monitors(&nodes[3], 0);
+
+	nodes[3].node.handle_commitment_signed(&nodes[1].node.get_our_node_id(), &bs_cs);
+	check_added_monitors(&nodes[3], 0);
+	assert!(nodes[3].node.get_and_clear_pending_msg_events().is_empty());
+
+	// Now double-check that the preimage is still in the 1<->3 channel and complete the pending
+	// monitor update, allowing node 3 to claim the payment on the 2<->3 channel. This also
+	// unblocks the 1<->3 channel, allowing node 3 to release the two blocked monitor updates and
+	// respond to the final commitment_signed.
+	assert!(get_monitor!(nodes[3], chan_3_id).get_stored_preimages().contains_key(&payment_hash));
+
+	nodes[3].chain_monitor.complete_sole_pending_chan_update(&chan_4_id);
+	let mut ds_msgs = nodes[3].node.get_and_clear_pending_msg_events();
+	assert_eq!(ds_msgs.len(), 2);
+	check_added_monitors(&nodes[3], 2);
+
+	match remove_first_msg_event_to_node(&nodes[1].node.get_our_node_id(), &mut ds_msgs) {
+		MessageSendEvent::SendRevokeAndACK { msg, .. } => {
+			nodes[1].node.handle_revoke_and_ack(&nodes[3].node.get_our_node_id(), &msg);
+			check_added_monitors(&nodes[1], 1);
+		}
+		_ => panic!(),
+	}
+
+	match remove_first_msg_event_to_node(&nodes[2].node.get_our_node_id(), &mut ds_msgs) {
+		MessageSendEvent::UpdateHTLCs { updates, .. } => {
+			nodes[2].node.handle_update_fulfill_htlc(&nodes[3].node.get_our_node_id(), &updates.update_fulfill_htlcs[0]);
+			check_added_monitors(&nodes[2], 1);
+			expect_payment_forwarded!(nodes[2], nodes[0], nodes[3], Some(1000), false, false);
+			let _cs_updates_for_a = get_htlc_update_msgs(&nodes[2], &nodes[0].node.get_our_node_id());
+
+			nodes[2].node.handle_commitment_signed(&nodes[3].node.get_our_node_id(), &updates.commitment_signed);
+			check_added_monitors(&nodes[2], 1);
+		},
+		_ => panic!(),
+	}
+
+	let (cs_raa, cs_cs) = get_revoke_commit_msgs(&nodes[2], &nodes[3].node.get_our_node_id());
+
+	nodes[3].node.handle_revoke_and_ack(&nodes[2].node.get_our_node_id(), &cs_raa);
+	check_added_monitors(&nodes[3], 1);
+
+	nodes[3].node.handle_commitment_signed(&nodes[2].node.get_our_node_id(), &cs_cs);
+	check_added_monitors(&nodes[3], 1);
+
+	let ds_raa = get_event_msg!(nodes[3], MessageSendEvent::SendRevokeAndACK, nodes[2].node.get_our_node_id());
+	nodes[2].node.handle_revoke_and_ack(&nodes[3].node.get_our_node_id(), &ds_raa);
+	check_added_monitors(&nodes[2], 1);
+
+	// Our current `ChannelMonitor`s store preimages one RAA longer than they need to. That's nice
+	// for safety, but means we have to send one more payment here to wipe the preimage.
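+	// (Until that extra payment's commitment dance completes, both monitors still hold the
+	// preimage, as the assertions below check.)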
+	assert!(get_monitor!(nodes[3], chan_3_id).get_stored_preimages().contains_key(&payment_hash));
+	assert!(get_monitor!(nodes[3], chan_4_id).get_stored_preimages().contains_key(&payment_hash));
+
+	send_payment(&nodes[1], &[&nodes[3]], 100_000);
+	assert!(!get_monitor!(nodes[3], chan_3_id).get_stored_preimages().contains_key(&payment_hash));
+
+	send_payment(&nodes[2], &[&nodes[3]], 100_000);
+	assert!(!get_monitor!(nodes[3], chan_4_id).get_stored_preimages().contains_key(&payment_hash));
+}
diff --git a/lightning/src/ln/channelmanager.rs b/lightning/src/ln/channelmanager.rs
index 6ba7396ebfe..67159ef52f3 100644
--- a/lightning/src/ln/channelmanager.rs
+++ b/lightning/src/ln/channelmanager.rs
@@ -306,6 +306,7 @@ pub(super) struct PendingAddHTLCInfo {
 	// Note that this may be an outbound SCID alias for the associated channel.
 	prev_short_channel_id: u64,
 	prev_htlc_id: u64,
+	prev_counterparty_node_id: Option<PublicKey>,
 	prev_channel_id: ChannelId,
 	prev_funding_outpoint: OutPoint,
 	prev_user_channel_id: u128,
@@ -349,9 +350,10 @@ pub(crate) struct HTLCPreviousHopData {
 	blinded_failure: Option<BlindedFailure>,
 	channel_id: ChannelId,
 
-	// This field is consumed by `claim_funds_from_hop()` when updating a force-closed backwards
+	// These fields are consumed by `claim_funds_from_hop()` when updating a force-closed backwards
 	// channel with a preimage provided by the forward channel.
 	outpoint: OutPoint,
+	counterparty_node_id: Option<PublicKey>,
 }
 
 enum OnionPayload {
@@ -755,13 +757,55 @@ enum BackgroundEvent {
 	},
 }
 
+/// A pointer to a channel that is unblocked when an event is surfaced
+#[derive(Debug)]
+pub(crate) struct EventUnblockedChannel {
+	counterparty_node_id: PublicKey,
+	funding_txo: OutPoint,
+	channel_id: ChannelId,
+	blocking_action: RAAMonitorUpdateBlockingAction,
+}
+
+impl Writeable for EventUnblockedChannel {
+	fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
+		self.counterparty_node_id.write(writer)?;
+		self.funding_txo.write(writer)?;
+		self.channel_id.write(writer)?;
+		self.blocking_action.write(writer)
+	}
+}
+
+impl MaybeReadable for EventUnblockedChannel {
+	fn read<R: Read>(reader: &mut R) -> Result<Option<Self>, DecodeError> {
+		let counterparty_node_id = Readable::read(reader)?;
+		let funding_txo = Readable::read(reader)?;
+		let channel_id = Readable::read(reader)?;
+		let blocking_action = match RAAMonitorUpdateBlockingAction::read(reader)? {
+			Some(blocking_action) => blocking_action,
+			None => return Ok(None),
+		};
+		Ok(Some(EventUnblockedChannel {
+			counterparty_node_id,
+			funding_txo,
+			channel_id,
+			blocking_action,
+		}))
+	}
+}
+
 #[derive(Debug)]
 pub(crate) enum MonitorUpdateCompletionAction {
 	/// Indicates that a payment ultimately destined for us was claimed and we should emit an
 	/// [`events::Event::PaymentClaimed`] to the user if we haven't yet generated such an event for
 	/// this payment. Note that this is only best-effort. On restart it's possible such a duplicate
 	/// event can be generated.
-	PaymentClaimed { payment_hash: PaymentHash },
+	PaymentClaimed {
+		payment_hash: PaymentHash,
+		/// A pending MPP claim which hasn't yet completed.
+		///
+		/// Not written to disk.
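+		///
+		/// The tuple identifies the counterparty, channel, and HTLC of the part claimed over
+		/// this channel, along with a pointer to the [`PendingMPPClaim`] shared by all parts.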
+		pending_mpp_claim: Option<(PublicKey, ChannelId, u64, PendingMPPClaimPointer)>,
+	},
 	/// Indicates an [`events::Event`] should be surfaced to the user and possibly resume the
 	/// operation of another channel.
 	///
@@ -772,7 +816,7 @@ pub(crate) enum MonitorUpdateCompletionAction {
 	/// outbound edge.
 	EmitEventAndFreeOtherChannel {
 		event: events::Event,
-		downstream_counterparty_and_funding_outpoint: Option<(PublicKey, OutPoint, ChannelId, RAAMonitorUpdateBlockingAction)>,
+		downstream_counterparty_and_funding_outpoint: Option<EventUnblockedChannel>,
 	},
 	/// Indicates we should immediately resume the operation of another channel, unless there is
 	/// some other reason why the channel is blocked. In practice this simply means immediately
@@ -795,13 +839,16 @@ pub(crate) enum MonitorUpdateCompletionAction {
 }
 
 impl_writeable_tlv_based_enum_upgradable!(MonitorUpdateCompletionAction,
-	(0, PaymentClaimed) => { (0, payment_hash, required) },
+	(0, PaymentClaimed) => {
+		(0, payment_hash, required),
+		(9999999999, pending_mpp_claim, (static_value, None)),
+	},
 	// Note that FreeOtherChannelImmediately should never be written - we were supposed to free
 	// *immediately*. However, for simplicity we implement read/write here.
 	(1, FreeOtherChannelImmediately) => {
 		(0, downstream_counterparty_node_id, required),
 		(2, downstream_funding_outpoint, required),
-		(4, blocking_action, required),
+		(4, blocking_action, upgradable_required),
 		// Note that by the time we get past the required read above, downstream_funding_outpoint will be
 		// filled in, so we can safely unwrap it here.
 		(5, downstream_channel_id, (default_value, ChannelId::v1_from_funding_outpoint(downstream_funding_outpoint.0.unwrap()))),
@@ -813,7 +860,7 @@ impl_writeable_tlv_based_enum_upgradable!(MonitorUpdateCompletionAction,
 	// monitor updates which aren't properly blocked or resumed, however that's fine - we don't
 	// support async monitor updates even in LDK 0.0.116 and once we do we'll require no
 	// downgrades to prior versions.
-		(1, downstream_counterparty_and_funding_outpoint, option),
+		(1, downstream_counterparty_and_funding_outpoint, upgradable_option),
 	},
 );
 
@@ -835,6 +882,26 @@ impl_writeable_tlv_based_enum!(EventCompletionAction,
 	};
 );
 
+#[derive(Debug)]
+pub(crate) struct PendingMPPClaim {
+	channels_without_preimage: Vec<(PublicKey, OutPoint, ChannelId, u64)>,
+	channels_with_preimage: Vec<(PublicKey, OutPoint, ChannelId)>,
+}
+
+#[derive(Clone)]
+pub(crate) struct PendingMPPClaimPointer(Arc<Mutex<PendingMPPClaim>>);
+
+impl PartialEq for PendingMPPClaimPointer {
+	fn eq(&self, o: &Self) -> bool { Arc::ptr_eq(&self.0, &o.0) }
+}
+impl Eq for PendingMPPClaimPointer {}
+
+impl core::fmt::Debug for PendingMPPClaimPointer {
+	fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> Result<(), core::fmt::Error> {
+		self.0.lock().unwrap().fmt(f)
+	}
+}
+
 #[derive(Clone, PartialEq, Eq, Debug)]
 /// If something is blocked on the completion of an RAA-generated [`ChannelMonitorUpdate`] we track
 /// the blocked action here. See enum variants for more info.
@@ -848,6 +915,16 @@ pub(crate) enum RAAMonitorUpdateBlockingAction {
 		/// The HTLC ID on the inbound edge.
 		htlc_id: u64,
 	},
+	/// We claimed an MPP payment across multiple channels. We have to block removing the payment
+	/// preimage from any monitor until the last monitor is updated to contain the payment
+	/// preimage. Otherwise we may not be able to replay the preimage on the monitor(s) that
+	/// weren't updated on startup.
+	///
+	/// This variant is *not* written to disk, instead being inferred from [`ChannelMonitor`]
+	/// state.
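+	///
+	/// Each channel involved in the claim holds a pointer to the same shared
+	/// [`PendingMPPClaim`], so this blocker is only released once every channel's
+	/// `ChannelMonitor` has been given the preimage.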
+	ClaimedMPPPayment {
+		pending_claim: PendingMPPClaimPointer,
+	}
 }
 
 impl RAAMonitorUpdateBlockingAction {
@@ -859,10 +936,16 @@ impl RAAMonitorUpdateBlockingAction {
 	}
 }
 
-impl_writeable_tlv_based_enum!(RAAMonitorUpdateBlockingAction,
-	(0, ForwardedPaymentInboundClaim) => { (0, channel_id, required), (2, htlc_id, required) }
-;);
+impl_writeable_tlv_based_enum_upgradable!(RAAMonitorUpdateBlockingAction,
+	(0, ForwardedPaymentInboundClaim) => { (0, channel_id, required), (2, htlc_id, required) },
+	unread_variants: ClaimedMPPPayment
+);
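+
+// With an unread variant, `RAAMonitorUpdateBlockingAction` is only `MaybeReadable`; this lets
+// callers which read it as an `Option` treat an unrecognized variant as `None`.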
+impl Readable for Option<RAAMonitorUpdateBlockingAction> {
+	fn read<R: Read>(reader: &mut R) -> Result<Self, DecodeError> {
+		Ok(RAAMonitorUpdateBlockingAction::read(reader)?)
+	}
+}
 
 /// State we hold per-peer.
 pub(super) struct PeerState<SP: Deref> where SP::Target: SignerProvider {
@@ -4692,6 +4775,7 @@ where
 		let mut per_source_pending_forward = [(
 			payment.prev_short_channel_id,
+			payment.prev_counterparty_node_id,
 			payment.prev_funding_outpoint,
 			payment.prev_channel_id,
 			payment.prev_user_channel_id,
@@ -4722,6 +4806,7 @@ where
 			user_channel_id: Some(payment.prev_user_channel_id),
 			outpoint: payment.prev_funding_outpoint,
 			channel_id: payment.prev_channel_id,
+			counterparty_node_id: payment.prev_counterparty_node_id,
 			htlc_id: payment.prev_htlc_id,
 			incoming_packet_shared_secret: payment.forward_info.incoming_shared_secret,
 			phantom_shared_secret: None,
@@ -4851,8 +4936,10 @@ where
 		// Process all of the forwards and failures for the channel in which the HTLCs were
 		// proposed to as a batch.
-		let pending_forwards = (incoming_scid, incoming_funding_txo, incoming_channel_id,
-			incoming_user_channel_id, htlc_forwards.drain(..).collect());
+		let pending_forwards = (
+			incoming_scid, Some(incoming_counterparty_node_id), incoming_funding_txo,
+			incoming_channel_id, incoming_user_channel_id, htlc_forwards.drain(..).collect()
+		);
 		self.forward_htlcs_without_forward_event(&mut [pending_forwards]);
 		for (htlc_fail, htlc_destination) in htlc_fails.drain(..) {
 			let failure = match htlc_fail {
@@ -4886,7 +4973,7 @@ where
 		let mut new_events = VecDeque::new();
 		let mut failed_forwards = Vec::new();
-		let mut phantom_receives: Vec<(u64, OutPoint, ChannelId, u128, Vec<(PendingHTLCInfo, u64)>)> = Vec::new();
+		let mut phantom_receives: Vec<(u64, Option<PublicKey>, OutPoint, ChannelId, u128, Vec<(PendingHTLCInfo, u64)>)> = Vec::new();
 		{
 			let mut forward_htlcs = new_hash_map();
 			mem::swap(&mut forward_htlcs, &mut self.forward_htlcs.lock().unwrap());
@@ -4900,7 +4987,7 @@ where
 					match forward_info {
 						HTLCForwardInfo::AddHTLC(PendingAddHTLCInfo {
 							prev_short_channel_id, prev_htlc_id, prev_channel_id, prev_funding_outpoint,
-							prev_user_channel_id, forward_info: PendingHTLCInfo {
+							prev_user_channel_id, prev_counterparty_node_id, forward_info: PendingHTLCInfo {
 								routing, incoming_shared_secret, payment_hash, outgoing_amt_msat,
 								outgoing_cltv_value, ..
 							}
@@ -4915,6 +5002,7 @@ where
 								user_channel_id: Some(prev_user_channel_id),
 								channel_id: prev_channel_id,
 								outpoint: prev_funding_outpoint,
+								counterparty_node_id: prev_counterparty_node_id,
 								htlc_id: prev_htlc_id,
 								incoming_packet_shared_secret: incoming_shared_secret,
 								phantom_shared_secret: $phantom_ss,
@@ -4977,7 +5065,10 @@ where
 										outgoing_cltv_value, Some(phantom_shared_secret), false, None,
 										current_height, self.default_configuration.accept_mpp_keysend)
 									{
-										Ok(info) => phantom_receives.push((prev_short_channel_id, prev_funding_outpoint, prev_channel_id, prev_user_channel_id, vec![(info, prev_htlc_id)])),
+										Ok(info) => phantom_receives.push((
+											prev_short_channel_id, prev_counterparty_node_id, prev_funding_outpoint,
+											prev_channel_id, prev_user_channel_id, vec![(info, prev_htlc_id)]
+										)),
 										Err(InboundHTLCErr { err_code, err_data, msg }) => failed_payment!(msg, err_code, err_data, Some(phantom_shared_secret))
 									}
 								},
@@ -5022,7 +5113,7 @@ where
 					let queue_fail_htlc_res = match forward_info {
 						HTLCForwardInfo::AddHTLC(PendingAddHTLCInfo {
 							prev_short_channel_id, prev_htlc_id, prev_channel_id, prev_funding_outpoint,
-							prev_user_channel_id, forward_info: PendingHTLCInfo {
+							prev_user_channel_id, prev_counterparty_node_id, forward_info: PendingHTLCInfo {
 								incoming_shared_secret, payment_hash, outgoing_amt_msat, outgoing_cltv_value,
 								routing: PendingHTLCRouting::Forward {
 									ref onion_packet, blinded, ..
@@ -5032,6 +5123,7 @@ where
 							let htlc_source = HTLCSource::PreviousHopData(HTLCPreviousHopData {
 								short_channel_id: prev_short_channel_id,
 								user_channel_id: Some(prev_user_channel_id),
+								counterparty_node_id: prev_counterparty_node_id,
 								channel_id: prev_channel_id,
 								outpoint: prev_funding_outpoint,
 								htlc_id: prev_htlc_id,
@@ -5160,7 +5252,7 @@ where
 					match forward_info {
 						HTLCForwardInfo::AddHTLC(PendingAddHTLCInfo {
 							prev_short_channel_id, prev_htlc_id, prev_channel_id, prev_funding_outpoint,
-							prev_user_channel_id, forward_info: PendingHTLCInfo {
+							prev_user_channel_id, prev_counterparty_node_id, forward_info: PendingHTLCInfo {
 								routing, incoming_shared_secret, payment_hash, incoming_amt_msat,
 								outgoing_amt_msat, skimmed_fee_msat, ..
 							}
@@ -5198,6 +5290,7 @@ where
 							prev_hop: HTLCPreviousHopData {
 								short_channel_id: prev_short_channel_id,
 								user_channel_id: Some(prev_user_channel_id),
+								counterparty_node_id: prev_counterparty_node_id,
 								channel_id: prev_channel_id,
 								outpoint: prev_funding_outpoint,
 								htlc_id: prev_htlc_id,
@@ -5230,6 +5323,7 @@ where
 							failed_forwards.push((HTLCSource::PreviousHopData(HTLCPreviousHopData {
 								short_channel_id: $htlc.prev_hop.short_channel_id,
 								user_channel_id: $htlc.prev_hop.user_channel_id,
+								counterparty_node_id: $htlc.prev_hop.counterparty_node_id,
 								channel_id: prev_channel_id,
 								outpoint: prev_funding_outpoint,
 								htlc_id: $htlc.prev_hop.htlc_id,
@@ -6115,7 +6209,7 @@ where
 		let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(self);
 
-		let mut sources = {
+		let sources = {
 			let mut claimable_payments = self.claimable_payments.lock().unwrap();
 			if let Some(payment) = claimable_payments.claimable_payments.remove(&payment_hash) {
 				let mut receiver_node_id = self.our_network_pubkey;
@@ -6210,18 +6304,46 @@ where
 			return;
 		}
 		if valid_mpp {
-			for htlc in sources.drain(..) {
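+			// If this claim was an MPP payment, set up one `PendingMPPClaim` listing each part
+			// still awaiting its monitor update; every channel's RAA blocker points at it.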
+			let pending_mpp_claim_ptr_opt = if sources.len() > 1 {
+				let channels_without_preimage = sources.iter().filter_map(|htlc| {
+					if let Some(cp_id) = htlc.prev_hop.counterparty_node_id {
+						let prev_hop = &htlc.prev_hop;
+						Some((cp_id, prev_hop.outpoint, prev_hop.channel_id, prev_hop.htlc_id))
+					} else {
+						None
+					}
+				}).collect();
+				Some(Arc::new(Mutex::new(PendingMPPClaim {
+					channels_without_preimage,
+					channels_with_preimage: Vec::new(),
+				})))
+			} else {
+				None
+			};
+			for htlc in sources {
+				let this_mpp_claim = pending_mpp_claim_ptr_opt.as_ref().and_then(|pending_mpp_claim|
+					if let Some(cp_id) = htlc.prev_hop.counterparty_node_id {
+						let claim_ptr = PendingMPPClaimPointer(Arc::clone(pending_mpp_claim));
+						Some((cp_id, htlc.prev_hop.channel_id, htlc.prev_hop.htlc_id, claim_ptr))
+					} else {
+						None
+					}
+				);
+				let raa_blocker = pending_mpp_claim_ptr_opt.as_ref().map(|pending_claim| {
+					RAAMonitorUpdateBlockingAction::ClaimedMPPPayment {
+						pending_claim: PendingMPPClaimPointer(Arc::clone(pending_claim)),
+					}
+				});
 				self.claim_funds_from_hop(
 					htlc.prev_hop, payment_preimage,
 					|_, definitely_duplicate| {
 						debug_assert!(!definitely_duplicate, "We shouldn't claim duplicatively from a payment");
-						Some(MonitorUpdateCompletionAction::PaymentClaimed { payment_hash })
+						(Some(MonitorUpdateCompletionAction::PaymentClaimed { payment_hash, pending_mpp_claim: this_mpp_claim }), raa_blocker)
 					}
 				);
 			}
-		}
-		if !valid_mpp {
-			for htlc in sources.drain(..) {
+		} else {
+			for htlc in sources {
 				let mut htlc_msat_height_data = htlc.value.to_be_bytes().to_vec();
 				htlc_msat_height_data.extend_from_slice(&self.best_block.read().unwrap().height.to_be_bytes());
 				let source = HTLCSource::PreviousHopData(htlc.prev_hop);
@@ -6239,7 +6361,9 @@ where
 		}
 	}
 
-	fn claim_funds_from_hop<ComplFunc: FnOnce(Option<u64>, bool) -> Option<MonitorUpdateCompletionAction>>(
+	fn claim_funds_from_hop<
+		ComplFunc: FnOnce(Option<u64>, bool) -> (Option<MonitorUpdateCompletionAction>, Option<RAAMonitorUpdateBlockingAction>)
+	>(
 		&self, prev_hop: HTLCPreviousHopData, payment_preimage: PaymentPreimage,
 		completion_action: ComplFunc,
 	) {
@@ -6279,11 +6403,15 @@ where
 				match fulfill_res {
 					UpdateFulfillCommitFetch::NewClaim { htlc_value_msat, monitor_update } => {
-						if let Some(action) = completion_action(Some(htlc_value_msat), false) {
+						let (action_opt, raa_blocker_opt) = completion_action(Some(htlc_value_msat), false);
+						if let Some(action) = action_opt {
 							log_trace!(logger, "Tracking monitor update completion action for channel {}: {:?}",
 								chan_id, action);
 							peer_state.monitor_update_blocked_actions.entry(chan_id).or_insert(Vec::new()).push(action);
 						}
+						if let Some(raa_blocker) = raa_blocker_opt {
+							peer_state.actions_blocking_raa_monitor_updates.entry(chan_id).or_insert_with(Vec::new).push(raa_blocker);
+						}
 						if !during_init {
 							handle_new_monitor_update!(self, prev_hop.outpoint, monitor_update, peer_state_lock,
 								peer_state, per_peer_state, chan);
@@ -6301,11 +6429,16 @@ where
 					}
 					UpdateFulfillCommitFetch::DuplicateClaim {} => {
-						let action = if let Some(action) = completion_action(None, true) {
+						let (action_opt, raa_blocker_opt) = completion_action(None, true);
+						if let Some(raa_blocker) = raa_blocker_opt {
+							debug_assert!(peer_state.actions_blocking_raa_monitor_updates.get(&chan_id).unwrap().contains(&raa_blocker));
+						}
+						let action = if let Some(action) = action_opt {
 							action
 						} else {
 							return;
 						};
+
 						mem::drop(peer_state_lock);
 						log_trace!(logger, "Completing monitor update completion action for channel {} as claim was redundant: {:?}",
@@ -6392,7 +6525,46 @@ where
 		// `ChannelMonitor` we've provided the above update to. Instead, note that `Event`s are
 		// generally always allowed to be duplicative (and it's specifically noted in
 		// `PaymentForwarded`).
-		self.handle_monitor_update_completion_actions(completion_action(None, false));
+		let (action_opt, raa_blocker_opt) = completion_action(None, false);
+
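+		// The backwards channel here may already be closed, so we may not have a `PeerState`
+		// entry for the counterparty; register the blocker under a fresh entry if needed.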
+		if let Some(raa_blocker) = raa_blocker_opt {
+			let counterparty_node_id = prev_hop.counterparty_node_id.or_else(||
+				// prev_hop.counterparty_node_id is always available for payments received after
+				// LDK 0.0.123, but for those received on 0.0.123 and claimed later, we need to
+				// look up the counterparty in the `action_opt`, if possible.
+				action_opt.as_ref().and_then(|action|
+					if let MonitorUpdateCompletionAction::PaymentClaimed { pending_mpp_claim, .. } = action {
+						pending_mpp_claim.as_ref().map(|(node_id, _, _, _)| *node_id)
+					} else { None }
+				)
+			);
+			if let Some(counterparty_node_id) = counterparty_node_id {
+				// TODO: Avoid always blocking the world for the write lock here.
+				let mut per_peer_state = self.per_peer_state.write().unwrap();
+				let peer_state_mutex = per_peer_state.entry(counterparty_node_id).or_insert_with(||
+					Mutex::new(PeerState {
+						channel_by_id: new_hash_map(),
+						inbound_channel_request_by_id: new_hash_map(),
+						latest_features: InitFeatures::empty(),
+						pending_msg_events: Vec::new(),
+						in_flight_monitor_updates: BTreeMap::new(),
+						monitor_update_blocked_actions: BTreeMap::new(),
+						actions_blocking_raa_monitor_updates: BTreeMap::new(),
+						is_connected: false,
+					}));
+				let mut peer_state = peer_state_mutex.lock().unwrap();
+
+				peer_state.actions_blocking_raa_monitor_updates
+					.entry(prev_hop.channel_id)
+					.or_insert_with(Vec::new)
+					.push(raa_blocker);
+			} else {
+				debug_assert!(false,
+					"RAA ChannelMonitorUpdate blockers are only set with PaymentClaimed completion actions, so we should always have a counterparty node id");
+			}
+		}
+
+		self.handle_monitor_update_completion_actions(action_opt);
 	}
 
 	fn finalize_claims(&self, sources: Vec<HTLCSource>) {
@@ -6429,7 +6601,12 @@ where
 			|htlc_claim_value_msat, definitely_duplicate| {
 				let chan_to_release =
 					if let Some(node_id) = next_channel_counterparty_node_id {
-						Some((node_id, next_channel_outpoint, next_channel_id, completed_blocker))
+						Some(EventUnblockedChannel {
+							counterparty_node_id: node_id,
+							funding_txo: next_channel_outpoint,
+							channel_id: next_channel_id,
+							blocking_action: completed_blocker
+						})
 					} else {
 						// We can only get `None` here if we are processing a
 						// `ChannelMonitor`-originated event, in which case we
@@ -6486,16 +6663,16 @@ where
 						}
 					}), "{:?}", *background_events);
 				}
-				None
+				(None, None)
 			} else if definitely_duplicate {
 				if let Some(other_chan) = chan_to_release {
-					Some(MonitorUpdateCompletionAction::FreeOtherChannelImmediately {
-						downstream_counterparty_node_id: other_chan.0,
-						downstream_funding_outpoint: other_chan.1,
-						downstream_channel_id: other_chan.2,
-						blocking_action: other_chan.3,
-					})
-				} else { None }
+					(Some(MonitorUpdateCompletionAction::FreeOtherChannelImmediately {
+						downstream_counterparty_node_id: other_chan.counterparty_node_id,
+						downstream_funding_outpoint: other_chan.funding_txo,
+						downstream_channel_id: other_chan.channel_id,
+						blocking_action: other_chan.blocking_action,
+					}), None)
+				} else { (None, None) }
 			} else {
 				let total_fee_earned_msat = if let Some(forwarded_htlc_value) = forwarded_htlc_value_msat {
 					if let Some(claimed_htlc_value) = htlc_claim_value_msat {
@@ -6504,7 +6681,7 @@ where
 					} else { None };
 				debug_assert!(skimmed_fee_msat <= total_fee_earned_msat,
 					"skimmed_fee_msat must always be included in total_fee_earned_msat");
-				Some(MonitorUpdateCompletionAction::EmitEventAndFreeOtherChannel {
+				(Some(MonitorUpdateCompletionAction::EmitEventAndFreeOtherChannel {
 					event: events::Event::PaymentForwarded {
 						prev_channel_id: Some(prev_channel_id),
 						next_channel_id: Some(next_channel_id),
@@ -6516,7 +6693,7 @@ where
 						outbound_amount_forwarded_msat: forwarded_htlc_value_msat,
 					},
 					downstream_counterparty_and_funding_outpoint: chan_to_release,
-				})
+				}), None)
 			}
 		});
 	},
@@ -6533,9 +6710,44 @@ where
 		debug_assert_ne!(self.claimable_payments.held_by_thread(), LockHeldState::HeldByThread);
 		debug_assert_ne!(self.per_peer_state.held_by_thread(), LockHeldState::HeldByThread);
 
+		let mut freed_channels = Vec::new();
+
 		for action in actions.into_iter() {
 			match action {
-				MonitorUpdateCompletionAction::PaymentClaimed { payment_hash } => {
+				MonitorUpdateCompletionAction::PaymentClaimed { payment_hash, pending_mpp_claim } => {
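+					// One channel's monitor update for this MPP claim has completed: move the
+					// corresponding part from `channels_without_preimage` to
+					// `channels_with_preimage`, and once no parts remain, free the RAA blockers
+					// on all of the claim's channels.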
+					if let Some((counterparty_node_id, chan_id, htlc_id, claim_ptr)) = pending_mpp_claim {
+						let per_peer_state = self.per_peer_state.read().unwrap();
+						per_peer_state.get(&counterparty_node_id).map(|peer_state_mutex| {
+							let mut peer_state = peer_state_mutex.lock().unwrap();
+							let blockers_entry = peer_state.actions_blocking_raa_monitor_updates.entry(chan_id);
+							if let btree_map::Entry::Occupied(mut blockers) = blockers_entry {
+								blockers.get_mut().retain(|blocker|
+									if let &RAAMonitorUpdateBlockingAction::ClaimedMPPPayment { pending_claim } = &blocker {
+										if *pending_claim == claim_ptr {
+											let mut pending_claim_state_lock = pending_claim.0.lock().unwrap();
+											let pending_claim_state = &mut *pending_claim_state_lock;
+											pending_claim_state.channels_without_preimage.retain(|(cp, outp, cid, hid)| {
+												if *cp == counterparty_node_id && *cid == chan_id && *hid == htlc_id {
+													pending_claim_state.channels_with_preimage.push((*cp, *outp, *cid));
+													false
+												} else { true }
+											});
+											if pending_claim_state.channels_without_preimage.is_empty() {
+												for (cp, outp, cid) in pending_claim_state.channels_with_preimage.iter() {
+													freed_channels.push((*cp, *outp, *cid, blocker.clone()));
+												}
+											}
+											!pending_claim_state.channels_without_preimage.is_empty()
+										} else { true }
+									} else { true }
+								);
+								if blockers.get().is_empty() {
+									blockers.remove();
+								}
+							}
+						});
+					}
+
 					let payment = self.claimable_payments.lock().unwrap().pending_claiming_payments.remove(&payment_hash);
 					if let Some(ClaimingPayment {
 						amount_msat,
@@ -6560,8 +6772,11 @@ where
 					event, downstream_counterparty_and_funding_outpoint
 				} => {
 					self.pending_events.lock().unwrap().push_back((event, None));
-					if let Some((node_id, funding_outpoint, channel_id, blocker)) = downstream_counterparty_and_funding_outpoint {
-						self.handle_monitor_update_release(node_id, funding_outpoint, channel_id, Some(blocker));
+					if let Some(unblocked) = downstream_counterparty_and_funding_outpoint {
+						self.handle_monitor_update_release(
+							unblocked.counterparty_node_id, unblocked.funding_txo,
+							unblocked.channel_id, Some(unblocked.blocking_action),
+						);
 					}
 				},
 				MonitorUpdateCompletionAction::FreeOtherChannelImmediately {
@@ -6576,6 +6791,10 @@ where
 				},
 			}
 		}
+
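+		// Finally, release the RAA blockers for any channels whose MPP claim completed above.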
+		for (node_id, funding_outpoint, channel_id, blocker) in freed_channels {
+			self.handle_monitor_update_release(node_id, funding_outpoint, channel_id, Some(blocker));
+		}
 	}
 
 	/// Handles a channel reentering a functional state, either due to reconnect or a monitor
@@ -6586,7 +6805,7 @@ where
 		pending_forwards: Vec<(PendingHTLCInfo, u64)>, pending_update_adds: Vec<msgs::UpdateAddHTLC>,
 		funding_broadcastable: Option<Transaction>,
 		channel_ready: Option<msgs::ChannelReady>, announcement_sigs: Option<msgs::AnnouncementSignatures>)
-	-> (Option<(u64, OutPoint, ChannelId, u128, Vec<(PendingHTLCInfo, u64)>)>, Option<(u64, Vec<msgs::UpdateAddHTLC>)>) {
+	-> (Option<(u64, Option<PublicKey>, OutPoint, ChannelId, u128, Vec<(PendingHTLCInfo, u64)>)>, Option<(u64, Vec<msgs::UpdateAddHTLC>)>) {
 		let logger = WithChannelContext::from(&self.logger, &channel.context, None);
 		log_trace!(logger, "Handling channel resumption for channel {} with {} RAA, {} commitment update, {} pending forwards, {} pending update_add_htlcs, {}broadcasting funding, {} channel ready, {} announcement",
 			&channel.context.channel_id(),
@@ -6602,8 +6821,11 @@ where
 		let mut htlc_forwards = None;
 		if !pending_forwards.is_empty() {
-			htlc_forwards = Some((short_channel_id, channel.context.get_funding_txo().unwrap(),
-				channel.context.channel_id(), channel.context.get_user_id(), pending_forwards));
+			htlc_forwards = Some((
+				short_channel_id, Some(channel.context.get_counterparty_node_id()),
+				channel.context.get_funding_txo().unwrap(), channel.context.channel_id(),
+				channel.context.get_user_id(), pending_forwards
+			));
 		}
 		let mut decode_update_add_htlcs = None;
 		if !pending_update_adds.is_empty() {
@@ -7646,15 +7868,15 @@ where
 	}
 
 	#[inline]
-	fn forward_htlcs(&self, per_source_pending_forwards: &mut [(u64, OutPoint, ChannelId, u128, Vec<(PendingHTLCInfo, u64)>)]) {
+	fn forward_htlcs(&self, per_source_pending_forwards: &mut [(u64, Option<PublicKey>, OutPoint, ChannelId, u128, Vec<(PendingHTLCInfo, u64)>)]) {
 		let push_forward_event = self.forward_htlcs_without_forward_event(per_source_pending_forwards);
 		if push_forward_event { self.push_pending_forwards_ev() }
 	}
 
 	#[inline]
-	fn forward_htlcs_without_forward_event(&self, per_source_pending_forwards: &mut [(u64, OutPoint, ChannelId, u128, Vec<(PendingHTLCInfo, u64)>)]) -> bool {
+	fn forward_htlcs_without_forward_event(&self, per_source_pending_forwards: &mut [(u64, Option<PublicKey>, OutPoint, ChannelId, u128, Vec<(PendingHTLCInfo, u64)>)]) -> bool {
 		let mut push_forward_event = false;
-		for &mut (prev_short_channel_id, prev_funding_outpoint, prev_channel_id, prev_user_channel_id, ref mut pending_forwards) in per_source_pending_forwards {
+		for &mut (prev_short_channel_id, prev_counterparty_node_id, prev_funding_outpoint, prev_channel_id, prev_user_channel_id, ref mut pending_forwards) in per_source_pending_forwards {
 			let mut new_intercept_events = VecDeque::new();
 			let mut failed_intercept_forwards = Vec::new();
 			if !pending_forwards.is_empty() {
@@ -7673,7 +7895,9 @@ where
 					match forward_htlcs.entry(scid) {
 						hash_map::Entry::Occupied(mut entry) => {
 							entry.get_mut().push(HTLCForwardInfo::AddHTLC(PendingAddHTLCInfo {
-								prev_short_channel_id, prev_funding_outpoint, prev_channel_id, prev_htlc_id, prev_user_channel_id, forward_info }));
+								prev_short_channel_id, prev_counterparty_node_id, prev_funding_outpoint,
+								prev_channel_id, prev_htlc_id, prev_user_channel_id, forward_info
+							}));
 						},
 						hash_map::Entry::Vacant(entry) => {
 							if !is_our_scid && forward_info.incoming_amt_msat.is_some() &&
@@ -7691,7 +7915,9 @@ where
 										intercept_id
 									}, None));
 								entry.insert(PendingAddHTLCInfo {
-									prev_short_channel_id, prev_funding_outpoint, prev_channel_id, prev_htlc_id, prev_user_channel_id, forward_info });
+									prev_short_channel_id, prev_counterparty_node_id, prev_funding_outpoint,
+									prev_channel_id, prev_htlc_id, prev_user_channel_id, forward_info
+								});
 							},
 							hash_map::Entry::Occupied(_) => {
 								let logger = WithContext::from(&self.logger, None, Some(prev_channel_id), Some(forward_info.payment_hash));
@@ -7699,6 +7925,7 @@ where
 								let htlc_source = HTLCSource::PreviousHopData(HTLCPreviousHopData {
 									short_channel_id: prev_short_channel_id,
 									user_channel_id: Some(prev_user_channel_id),
+									counterparty_node_id: prev_counterparty_node_id,
 									outpoint: prev_funding_outpoint,
 									channel_id: prev_channel_id,
 									htlc_id: prev_htlc_id,
@@ -7718,7 +7945,9 @@ where
 								// payments are being processed.
 								push_forward_event |= forward_htlcs_empty && decode_update_add_htlcs_empty;
 								entry.insert(vec!(HTLCForwardInfo::AddHTLC(PendingAddHTLCInfo {
-									prev_short_channel_id, prev_funding_outpoint, prev_channel_id, prev_htlc_id, prev_user_channel_id, forward_info })));
+									prev_short_channel_id, prev_counterparty_node_id, prev_funding_outpoint,
+									prev_channel_id, prev_htlc_id, prev_user_channel_id, forward_info
+								})));
 							}
 						}
 					}
@@ -9507,6 +9736,7 @@ where
 							htlc_id: htlc.prev_htlc_id,
 							incoming_packet_shared_secret: htlc.forward_info.incoming_shared_secret,
 							phantom_shared_secret: None,
+							counterparty_node_id: htlc.prev_counterparty_node_id,
 							outpoint: htlc.prev_funding_outpoint,
 							channel_id: htlc.prev_channel_id,
 							blinded_failure: htlc.forward_info.routing.blinded_failure(),
@@ -10616,6 +10846,7 @@ impl_writeable_tlv_based!(HTLCPreviousHopData, {
 	// Note that by the time we get past the required read for type 2 above, outpoint will be
 	// filled in, so we can safely unwrap it here.
 	(9, channel_id, (default_value, ChannelId::v1_from_funding_outpoint(outpoint.0.unwrap()))),
+	(11, counterparty_node_id, option),
 });
 
 impl Writeable for ClaimableHTLC {
@@ -10772,6 +11003,7 @@ impl_writeable_tlv_based!(PendingAddHTLCInfo, {
 	// Note that by the time we get past the required read for type 6 above, prev_funding_outpoint will be
 	// filled in, so we can safely unwrap it here.
 	(7, prev_channel_id, (default_value, ChannelId::v1_from_funding_outpoint(prev_funding_outpoint.0.unwrap()))),
+	(9, prev_counterparty_node_id, option),
 });
 
 impl Writeable for HTLCForwardInfo {
@@ -12049,7 +12281,12 @@ where
 			for action in actions.iter() {
 				if let MonitorUpdateCompletionAction::EmitEventAndFreeOtherChannel {
 					downstream_counterparty_and_funding_outpoint:
-						Some((blocked_node_id, _blocked_channel_outpoint, blocked_channel_id, blocking_action)), ..
+						Some(EventUnblockedChannel {
+							counterparty_node_id: blocked_node_id,
+							funding_txo: _,
+							channel_id: blocked_channel_id,
+							blocking_action,
+						}), ..
 				} = action {
 					if let Some(blocked_peer_state) = per_peer_state.get(blocked_node_id) {
 						log_trace!(logger,
diff --git a/lightning/src/util/ser_macros.rs b/lightning/src/util/ser_macros.rs
index 255fd3ebc32..a45c5029230 100644
--- a/lightning/src/util/ser_macros.rs
+++ b/lightning/src/util/ser_macros.rs
@@ -1001,7 +1001,7 @@ macro_rules! _impl_writeable_tlv_based_enum_common {
 		impl $crate::util::ser::Writeable for $st {
 			fn write<W: $crate::util::ser::Writer>(&self, writer: &mut W) -> Result<(), $crate::io::Error> {
 				match self {
-					$($st::$variant_name { $(ref $field),* } => {
+					$($st::$variant_name { $(ref $field, )* .. } => {
 						let id: u8 = $variant_id;
 						id.write(writer)?;
 						$crate::write_tlv_fields!(writer, {
@@ -1099,10 +1099,12 @@ macro_rules! impl_writeable_tlv_based_enum_upgradable {
 		{$(($type: expr, $field: ident, $fieldty: tt)),* $(,)*}
 	),* $(,)*
 	$(;
-	$(($tuple_variant_id: expr, $tuple_variant_name: ident)),* $(,)*)*) => {
+	$(($tuple_variant_id: expr, $tuple_variant_name: ident)),* $(,)*)?
+	$(unread_variants: $($unread_variant: ident),*)?) => {
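+		// Variants listed after `unread_variants:` are written with the reserved variant id 255
+		// and no fields, so the read path below treats them like any unknown odd variant and
+		// discards them.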
 		$crate::_impl_writeable_tlv_based_enum_common!($st,
-			$(($variant_id, $variant_name) => {$(($type, $field, $fieldty)),*}),*;
-			$($(($tuple_variant_id, $tuple_variant_name)),*)*);
+			$(($variant_id, $variant_name) => {$(($type, $field, $fieldty)),*}),*
+			$(, $((255, $unread_variant) => {}),*)?
+			; $($(($tuple_variant_id, $tuple_variant_name)),*)?);
 
 		impl $crate::util::ser::MaybeReadable for $st {
 			fn read<R: $crate::io::Read>(reader: &mut R) -> Result<Option<Self>, $crate::ln::msgs::DecodeError> {
@@ -1126,7 +1128,9 @@ macro_rules! impl_writeable_tlv_based_enum_upgradable {
 					$($($tuple_variant_id => {
 						Ok(Some($st::$tuple_variant_name(Readable::read(reader)?)))
 					}),*)*
-					_ if id % 2 == 1 => {
+					// Note that we explicitly match 255 here to reserve it for use in
+					// `unread_variants`.
+					255|_ if id % 2 == 1 => {
 						// Assume that a $variant_id was written, not a $tuple_variant_id, and read
 						// the length prefix and discard the correct number of bytes.
 						let tlv_len: $crate::util::ser::BigSize = $crate::util::ser::Readable::read(reader)?;