66import logging
77import os
88import json
9+ import base64
910from datetime import datetime , timezone
1011from typing import Optional , Dict
1112
@@ -184,23 +185,98 @@ def extract_context_from_http_event_or_context(event, lambda_context):
184185 return trace_id , parent_id , sampling_priority
185186
186187
187- def extract_context_from_sqs_event_or_context (event , lambda_context ):
def create_sns_event(message):
    """
    Wrap an SNS message in a single-record SNS event envelope.

    The given message is placed, unmodified, under the "Sns" key of the
    only entry in "Records".
    """
    record = {
        "EventSource": "aws:sns",
        "EventVersion": "1.0",
        "Sns": message,
    }
    return {"Records": [record]}
198+
199+
def extract_context_from_sqs_or_sns_event_or_context(event, lambda_context):
    """
    Extract Datadog trace context from the first record of an SQS or SNS event.

    The "_datadog" message attribute of the first record is parsed as JSON.
    SQS records carry it under "messageAttributes" / "stringValue"; SNS
    records carry it under "Sns" -> "MessageAttributes" / "Value". When an
    SQS record's body is itself a JSON SNS notification (SNS => SQS), the
    attributes of the embedded notification are used instead.

    If the "_datadog" attribute is absent, an empty JSON object is parsed and
    the three returned values are None. If anything raises along the way,
    the context is taken from the lambda context instead.
    """
    try:
        original_record = event["Records"][0]
        record = original_record

        # logic to deal with SNS => SQS event
        if "body" in record:
            try:
                payload = json.loads(record["body"])
                is_notification = payload.get("Type", "") == "Notification"
                if is_notification and "TopicArn" in payload:
                    logger.debug("Found SNS message inside SQS event")
                    record = get_first_record(create_sns_event(payload))
            except Exception:
                # Body is not an SNS notification; keep using the SQS record.
                record = original_record

        sns_attributes = record.get("Sns", {}).get("MessageAttributes", {})
        attributes = record.get("messageAttributes", sns_attributes)
        dd_payload = attributes.get("_datadog", {})
        sns_style_value = dd_payload.get("Value", r"{}")
        dd_data = json.loads(dd_payload.get("stringValue", sns_style_value))

        return (
            dd_data.get(TraceHeader.TRACE_ID),
            dd_data.get(TraceHeader.PARENT_ID),
            dd_data.get(TraceHeader.SAMPLING_PRIORITY),
        )
    except Exception as err:
        logger.debug("The trace extractor returned with error %s", err)
        return extract_context_from_lambda_context(lambda_context)
236+
237+
def extract_context_from_eventbridge_event(event, lambda_context):
    """
    Extract datadog trace context from an EventBridge message's detail.

    Reads the "_datadog" entry of event["detail"]. When that entry is missing
    or empty, or when reading it raises, the context is taken from the lambda
    context instead.
    """
    try:
        dd_context = event["detail"].get("_datadog")
        if dd_context:
            return (
                dd_context.get(TraceHeader.TRACE_ID),
                dd_context.get(TraceHeader.PARENT_ID),
                dd_context.get(TraceHeader.SAMPLING_PRIORITY),
            )
        return extract_context_from_lambda_context(lambda_context)
    except Exception as err:
        logger.debug("The trace extractor returned with error %s", err)
        return extract_context_from_lambda_context(lambda_context)
255+
256+
def extract_context_from_kinesis_event(event, lambda_context):
    """
    Extract datadog trace context from a Kinesis Stream's base64 encoded data string.

    The first record's "kinesis" -> "data" value is base64-decoded, parsed as
    JSON, and its "_datadog" entry is read.

    Returns a (trace_id, parent_id, sampling_priority) tuple. When the record
    has no data, the payload has no "_datadog" entry, or decoding/parsing
    raises, the context is taken from the lambda context instead.
    """
    try:
        record = get_first_record(event)
        data = record.get("kinesis", {}).get("data")
        if not data:
            # Nothing to decode: fall back explicitly instead of relying on
            # an unbound local being caught by the handler below.
            return extract_context_from_lambda_context(lambda_context)

        # Decode as UTF-8 (a superset of ASCII) so JSON payloads containing
        # non-ASCII characters do not lose their trace context.
        data_str = base64.b64decode(data).decode("utf-8")
        dd_ctx = json.loads(data_str).get("_datadog")
        if not dd_ctx:
            return extract_context_from_lambda_context(lambda_context)

        trace_id = dd_ctx.get(TraceHeader.TRACE_ID)
        parent_id = dd_ctx.get(TraceHeader.PARENT_ID)
        sampling_priority = dd_ctx.get(TraceHeader.SAMPLING_PRIORITY)
        return trace_id, parent_id, sampling_priority
    except Exception as e:
        logger.debug("The trace extractor returned with error %s", e)
        return extract_context_from_lambda_context(lambda_context)
205281
206282
@@ -230,6 +306,7 @@ def extract_dd_trace_context(event, lambda_context, extractor=None):
230306 """
231307 global dd_trace_context
232308 trace_context_source = None
309+ event_source = parse_event_source (event )
233310
234311 if extractor is not None :
235312 (
@@ -243,12 +320,24 @@ def extract_dd_trace_context(event, lambda_context, extractor=None):
243320 parent_id ,
244321 sampling_priority ,
245322 ) = extract_context_from_http_event_or_context (event , lambda_context )
246- elif "Records" in event :
323+ elif event_source . equals ( EventTypes . SNS ) or event_source . equals ( EventTypes . SQS ) :
247324 (
248325 trace_id ,
249326 parent_id ,
250327 sampling_priority ,
251- ) = extract_context_from_sqs_event_or_context (event , lambda_context )
328+ ) = extract_context_from_sqs_or_sns_event_or_context (event , lambda_context )
329+ elif event_source .equals (EventTypes .EVENTBRIDGE ):
330+ (
331+ trace_id ,
332+ parent_id ,
333+ sampling_priority ,
334+ ) = extract_context_from_eventbridge_event (event , lambda_context )
335+ elif event_source .equals (EventTypes .KINESIS ):
336+ (
337+ trace_id ,
338+ parent_id ,
339+ sampling_priority ,
340+ ) = extract_context_from_kinesis_event (event , lambda_context )
252341 else :
253342 trace_id , parent_id , sampling_priority = extract_context_from_lambda_context (
254343 lambda_context
@@ -556,6 +645,8 @@ def create_inferred_span_from_http_api_event(event, context):
556645
557646
558647def create_inferred_span_from_sqs_event (event , context ):
648+ trace_ctx = tracer .current_trace_context ()
649+
559650 event_record = get_first_record (event )
560651 event_source_arn = event_record ["eventSourceARN" ]
561652 queue_name = event_source_arn .split (":" )[- 1 ]
@@ -574,11 +665,37 @@ def create_inferred_span_from_sqs_event(event, context):
574665 "resource" : queue_name ,
575666 "span_type" : "web" ,
576667 }
668+ start_time = int (request_time_epoch ) / 1000
669+
670+ # logic to deal with SNS => SQS event
671+ sns_span = None
672+ if "body" in event_record :
673+ body_str = event_record .get ("body" , {})
674+ try :
675+ body = json .loads (body_str )
676+ if body .get ("Type" , "" ) == "Notification" and "TopicArn" in body :
677+ logger .debug ("Found SNS message inside SQS event" )
678+ sns_span = create_inferred_span_from_sns_event (
679+ create_sns_event (body ), context
680+ )
681+ sns_span .finish (finish_time = start_time )
682+ except Exception as e :
683+ logger .debug (
684+ "Unable to create SNS span from SQS message, with error %s" % e
685+ )
686+ pass
687+
688+ # trace context needs to be set again as it is reset
689+ # when sns_span.finish executes
690+ tracer .context_provider .activate (trace_ctx )
577691 tracer .set_tags ({"_dd.origin" : "lambda" })
578692 span = tracer .trace ("aws.sqs" , ** args )
579693 if span :
580694 span .set_tags (tags )
581- span .start = int (request_time_epoch ) / 1000
695+ span .start = start_time
696+ if sns_span :
697+ span .parent_id = sns_span .span_id
698+
582699 return span
583700
584701
@@ -594,9 +711,12 @@ def create_inferred_span_from_sns_event(event, context):
594711 "topic_arn" : topic_arn ,
595712 "message_id" : sns_message ["MessageId" ],
596713 "type" : sns_message ["Type" ],
597- "subject" : sns_message ["Subject" ],
598- "event_subscription_arn" : event_record ["EventSubscriptionArn" ],
599714 }
715+
716+ # Subject not available in SNS => SQS scenario
717+ if "Subject" in sns_message and sns_message ["Subject" ]:
718+ tags ["subject" ] = sns_message ["Subject" ]
719+
600720 InferredSpanInfo .set_tags (tags , tag_source = "self" , synchronicity = "async" )
601721 sns_dt_format = "%Y-%m-%dT%H:%M:%S.%fZ"
602722 timestamp = event_record ["Sns" ]["Timestamp" ]
@@ -644,7 +764,7 @@ def create_inferred_span_from_kinesis_event(event, context):
644764 span = tracer .trace ("aws.kinesis" , ** args )
645765 if span :
646766 span .set_tags (tags )
647- span .start = int ( request_time_epoch )
767+ span .start = request_time_epoch
648768 return span
649769
650770
@@ -662,7 +782,7 @@ def create_inferred_span_from_dynamodb_event(event, context):
662782 "event_name" : event_record ["eventName" ],
663783 "event_version" : event_record ["eventVersion" ],
664784 "stream_view_type" : dynamodb_message ["StreamViewType" ],
665- "size_bytes" : dynamodb_message ["SizeBytes" ],
785+ "size_bytes" : str ( dynamodb_message ["SizeBytes" ]) ,
666786 }
667787 InferredSpanInfo .set_tags (tags , synchronicity = "async" , tag_source = "self" )
668788 request_time_epoch = event_record ["dynamodb" ]["ApproximateCreationDateTime" ]
@@ -690,8 +810,8 @@ def create_inferred_span_from_s3_event(event, context):
690810 "bucketname" : bucket_name ,
691811 "bucket_arn" : event_record ["s3" ]["bucket" ]["arn" ],
692812 "object_key" : event_record ["s3" ]["object" ]["key" ],
693- "object_size" : event_record ["s3" ]["object" ]["size" ],
694- "object_etag" : event_record ["s3" ]["etag " ],
813+ "object_size" : str ( event_record ["s3" ]["object" ]["size" ]) ,
814+ "object_etag" : event_record ["s3" ]["object" ][ "eTag " ],
695815 }
696816 InferredSpanInfo .set_tags (tags , synchronicity = "async" , tag_source = "self" )
697817 dt_format = "%Y-%m-%dT%H:%M:%S.%fZ"
@@ -786,7 +906,7 @@ def create_function_execution_span(
786906
787907
788908class InferredSpanInfo (object ):
789- BASE_NAME = "inferred_span "
909+ BASE_NAME = "_inferred_span "
790910 SYNCHRONICITY = f"{ BASE_NAME } .synchronicity"
791911 TAG_SOURCE = f"{ BASE_NAME } .tag_source"
792912
0 commit comments