Skip to content

Commit a03b9be

Browse files
authored
Merge pull request #1 from PADL/amp-fixes
AmpCode review fixes
2 parents 7ec1cd8 + 26cad34 commit a03b9be

3 files changed

Lines changed: 47 additions & 28 deletions

File tree

Package.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ let package = Package(
3939
),
4040
],
4141
dependencies: [
42-
.package(url: "https://github.com/PADL/IORingSwift", branch: "main"),
42+
.package(url: "https://github.com/PADL/IORingSwift", from: "1.0.0"),
4343
.package(url: "https://github.com/apple/swift-async-algorithms", from: "1.0.0"),
4444
.package(url: "https://github.com/apple/swift-system", from: "1.2.1"),
4545
.package(url: "https://github.com/apple/swift-argument-parser", from: "1.2.0"),

Sources/NetLink/NFNetLink.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,8 @@ public final class NFNLLog: Sendable {
140140

141141
public init(family: sa_family_t = sa_family_t(AF_BRIDGE), group: UInt16) throws {
142142
_socket = try NLSocket(protocol: NETLINK_NETFILTER)
143-
_log = NLObject(consumingObj: nfnl_log_alloc())
143+
guard let logObj = nfnl_log_alloc() else { throw NLError.noMemory }
144+
_log = NLObject(consumingObj: logObj)
144145
try throwingNLError {
145146
nfnl_log_pf_bind(_socket._sk, UInt8(family))
146147
}

Sources/NetLink/NetLink.swift

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -73,18 +73,24 @@ Sendable, Equatable, Hashable, CustomStringConvertible {
7373
try withUnsafeMutablePointer(to: &obj) { objRef in
7474
_ = try throwingNLError {
7575
nl_msg_parse(msg, { obj, objRef in
76-
nl_object_get(obj)
77-
objRef!
76+
let ptr = objRef!
7877
.withMemoryRebound(
79-
to: OpaquePointer.self,
78+
to: OpaquePointer?.self,
8079
capacity: 1
81-
) { objRef in
82-
objRef.pointee = obj!
83-
}
80+
) { $0 }
81+
if let existing = ptr.pointee {
82+
nl_object_put(existing)
83+
}
84+
nl_object_get(obj)
85+
ptr.pointee = obj
8486
}, objRef)
8587
}
8688
}
8789

90+
guard obj != nil else {
91+
throw NLError.invalidArgument
92+
}
93+
8894
self.init(obj: obj, constructFromObject: constructFromObject)
8995
nl_object_put(obj)
9096
}
@@ -222,15 +228,15 @@ private func NLSocket_ErrCB(
222228
let hdr = err.pointee.msg
223229
debugPrint("NLSocket_ErrCB: error \(err.pointee)")
224230
nlSocket.yield(sequence: hdr.nlmsg_seq, with: Result.failure(Errno(rawValue: -err.pointee.error)))
225-
return err.pointee.error
231+
return CInt(NL_SKIP.rawValue)
226232
}
227233

228234
public final class NLSocket: @unchecked
229235
Sendable {
230236
private typealias Continuation = CheckedContinuation<NLObjectConstructible, Error>
231237
private typealias Stream = AsyncThrowingStream<NLObjectConstructible, Error>
232238
private typealias Ack = CheckedContinuation<(), Error>
233-
public typealias Channel = AsyncThrowingChannel<NLObjectConstructible, Error>
239+
public typealias NotificationStream = AsyncThrowingStream<NLObjectConstructible, Error>
234240

235241
private enum _Request {
236242
case continuation(Continuation)
@@ -248,12 +254,18 @@ Sendable {
248254
}
249255

250256
let _sk: OpaquePointer!
257+
let _queue = DispatchQueue(label: "NLSocket")
251258
private let _readSource: any DispatchSourceRead
252259
private let _requests = Mutex<[UInt32: _Request]>([:])
253260

254-
public let notifications = Channel()
261+
private let _notificationsContinuation: NotificationStream.Continuation
262+
public let notifications: NotificationStream
255263

256264
public init(protocol: Int32) throws {
265+
var continuation: NotificationStream.Continuation!
266+
notifications = NotificationStream(bufferingPolicy: .unbounded) { continuation = $0 }
267+
_notificationsContinuation = continuation
268+
257269
guard let sk = nl_socket_alloc() else { throw NLError.noMemory }
258270
nl_socket_disable_seq_check(sk)
259271
_sk = sk
@@ -266,7 +278,7 @@ Sendable {
266278
let fd = nl_socket_get_fd(sk)
267279
precondition(fd >= 0)
268280

269-
_readSource = DispatchSource.makeReadSource(fileDescriptor: fd, queue: .main)
281+
_readSource = DispatchSource.makeReadSource(fileDescriptor: fd, queue: _queue)
270282
_readSource.setEventHandler(handler: onReadReady)
271283

272284
nl_socket_modify_cb(
@@ -302,6 +314,7 @@ Sendable {
302314

303315
deinit {
304316
_readSource.cancel()
317+
_notificationsContinuation.finish()
305318
nl_socket_free(_sk)
306319
}
307320

@@ -344,13 +357,15 @@ Sendable {
344357
}
345358

346359
public func useNextSequenceNumber() -> UInt32 {
347-
var nextSequenceNumber: UInt32
360+
_queue.sync {
361+
var nextSequenceNumber: UInt32
348362

349-
repeat {
350-
nextSequenceNumber = nl_socket_use_seq(_sk)
351-
} while nextSequenceNumber == 0
363+
repeat {
364+
nextSequenceNumber = nl_socket_use_seq(_sk)
365+
} while nextSequenceNumber == 0
352366

353-
return nextSequenceNumber
367+
return nextSequenceNumber
368+
}
354369
}
355370

356371
private func _lookup(sequence: UInt32, forceRemove: Bool) -> _Request? {
@@ -405,13 +420,7 @@ Sendable {
405420
}
406421
}
407422
} else {
408-
Task {
409-
do {
410-
try await notifications.send(result.get())
411-
} catch {
412-
notifications.fail(error)
413-
}
414-
}
423+
_notificationsContinuation.yield(with: result)
415424
}
416425
}
417426

@@ -466,7 +475,12 @@ Sendable {
466475
}
467476
stream = _stream
468477
}
469-
try message.send(on: self)
478+
do {
479+
try message.send(on: self)
480+
} catch {
481+
_requests.withLock { $0.removeValue(forKey: sequence) }
482+
throw error
483+
}
470484
return stream
471485
}
472486
}
@@ -764,8 +778,10 @@ struct NLMessage: ~Copyable {
764778
}
765779

766780
func append(opaque value: UnsafePointer<some Any>) throws {
767-
_ = try withUnsafeBytes(of: value.pointee) { value in
768-
try append(Array(value))
781+
try withUnsafeBytes(of: value.pointee) { bytes in
782+
try throwingNLError {
783+
nlmsg_append(_msg, UnsafeMutableRawPointer(mutating: bytes.baseAddress), bytes.count, CInt(NLMSG_ALIGNTO))
784+
}
769785
}
770786
}
771787

@@ -858,7 +874,9 @@ struct NLMessage: ~Copyable {
858874
}
859875

860876
func send(on socket: NLSocket) throws {
861-
try throwingNLError { nl_send_auto(socket._sk, _msg) }
877+
try socket._queue.sync {
878+
try throwingNLError { nl_send_auto(socket._sk, _msg) }
879+
}
862880
}
863881

864882
deinit {

0 commit comments

Comments
 (0)