11#!/usr/bin/env python
22#Lexi Morrissey, Mahony Lab @ Pennsylvania State University
3- #Last updated 03.01.2023
3+ #Last updated 04.22.2024
44#Contains methods for read allocation procedure of Allo.
55
66from Allo import predictPeak
1717import sys
1818import multiprocessing
1919import re
20+ import absl .logging
21+ absl .logging .set_verbosity (absl .logging .ERROR )
22+ import contextlib , io
2023
2124#Add reads to UMR dictionary
2225def addToDict (tempFile , genLand , spliceD , seq ):
@@ -31,8 +34,7 @@ def addToDict(tempFile, genLand, spliceD, seq):
3134 count = count + 1
3235 if seq == 0 and not (count % 2 ) == 0 :
3336 continue
34- #Getting closest 5bp window
35- l [3 ] = int (l [3 ]))
37+ l [3 ] = int (l [3 ])
3638 #Add to dictionary
3739 key = l [2 ] + ";" + str (l [3 ])
3840 if key in genLand :
@@ -65,26 +67,24 @@ def addToDict(tempFile, genLand, spliceD, seq):
6567
6668#Used to get counts in regions around multimapped reads
6769def getArray (read , winSize , genLand , spliceD ):
68- bin = int (winSize / 100 )
6970 #Spliced array using junction info
7071 if spliceD :
7172 #Parsing cigar string
7273 chars = re .findall ('[A-Za-z]+' , read [5 ])
7374 num = re .findall ('\d+\.\d+|\d+' , read [5 ])
7475 gap_loc = {}
75- start = int (round ( int ( read [3 ]) / bin ) * bin )
76+ start = int (read [3 ])
7677 chr = read [2 ]
7778 r_end = start #Get information about end of read to use for splice variants
7879 #print("cigar: " + read[5], flush=True)
7980 for i in range (0 ,len (chars )):
8081 if chars [i ]== "M" or chars [i ]== "D" or chars [i ]== "=" or chars [i ]== "X" :
8182 r_end = r_end + int (num [i ])
8283 elif chars [i ]== "N" :
83- gap_loc [int ( round (( int ( r_end ) + 4 ) / bin ) * bin ) ]= int (round (( int ( num [i ]) - 4 ) / bin ) * bin )
84+ gap_loc [r_end + 1 ]= int (num [i ])
8485 r_end = r_end + int (num [i ])
85- r_end = int (round (int (r_end )/ bin )* bin )
8686 array = []
87- k = start - bin
87+ k = start
8888 l = 0
8989 #Upstream counts
9090 while l <= math .floor (winSize / 2 ):
@@ -98,8 +98,8 @@ def getArray(read, winSize, genLand, spliceD):
9898 array .insert (0 ,genLand [key ])
9999 else :
100100 array .insert (0 ,0 )
101- k -= bin
102- l += bin
101+ k -= 1
102+ l += 1
103103 start_pos = k
104104 #Downstream counts
105105 k = start
@@ -116,8 +116,8 @@ def getArray(read, winSize, genLand, spliceD):
116116 array .append (genLand [key ])
117117 else :
118118 array .append (0 )
119- k += bin
120- l += bin
119+ k += 1
120+ l += 1
121121 #print("down2: " + str(k), flush=True)
122122 key = chr + ";" + str (k )
123123 if key in spliceD and spliceD [key ] > 0 :
@@ -127,14 +127,13 @@ def getArray(read, winSize, genLand, spliceD):
127127 array .append (genLand [key ])
128128 else :
129129 array .append (0 )
130- k += bin
131- l += bin
130+ k += 1
131+ l += 1
132132 #Non-spliced array
133133 else :
134134 array = []
135- pos = int (read [3 ])
136- k = pos - math .floor (winSize / 2 )
137- while k < int (pos )+ math .floor (winSize / 2 ):
135+ pos = round (int (read [3 ])/ 100 )* 100
136+ for k in range (pos - math .floor (winSize / 2 ),int (pos )+ math .floor (winSize / 2 )):
138137 key = read [2 ] + ";" + str (k )
139138 #Seeing if current pos in genetic landscape
140139 if key in genLand :
@@ -155,34 +154,28 @@ def readAssign(rBlock, samOut, winSize, genLand, model, cnn_scores, rc, rmz, mod
155154 allZ = True #seeing if all zero regions
156155 for i in rBlock :
157156 #Find closest 100 window, use that score instead if it's already been assigned, saves time
158- pos = i [2 ]+ str (round (int (i [3 ])/ 100 )* 100 )
159- if pos in cnn_scores :
160- #scores_nn.append(cnn_scores[pos])
157+ pos = i [2 ]+ str (i [3 ])
158+ countArray = getArray (i , winSize , genLand , spliceD )
159+ s = sum (countArray )
160+ if s > 0 :
161161 allZ = False
162- else :
163- countArray = getArray (i , winSize , genLand , spliceD )
164- s = sum (countArray )
165- if s > 0 :
166- allZ = False
167- #Allocation options
168- if rc == 1 :
169- if s == 0 :
170- scores_rc .append (1 )
171- else :
172- scores_rc .append (s + 1 )
173- continue
174- if rc == 2 :
175- scores_rc .append (1 )
176- continue
177- #Use no read score if zero region
162+ #Allocation options
163+ if rc == 1 :
178164 if s == 0 :
179- scores_nn .append (0.0012 * (s + 1 ))
180- elif s <= 5 :
181- scores_nn .append (0.0062 * (s + 1 ))
165+ scores_rc .append (1 )
182166 else :
183- nn = predictPeak .predictNN (countArray , winSize , model )
184- scores_nn .append (nn * (s + 1 ))
185- cnn_scores [pos ] = (nn * (s + 1 ))
167+ scores_rc .append (s + 1 )
168+ continue
169+ if rc == 2 :
170+ scores_rc .append (1 )
171+ continue
172+ #Use no read score if zero region
173+ if s == 0 :
174+ scores_nn .append (0.0012 * (s + 1 ))
175+ else :
176+ nn = predictPeak .predictNN (countArray , winSize , model )
177+ scores_nn .append (nn * (s + 1 ))
178+ cnn_scores [pos ] = (nn * (s + 1 ))
186179
187180 #Removing reads that mapped to all zero regions
188181 if allZ and rmz == 1 :
@@ -382,21 +375,20 @@ def parseUniq(tempFile, winSize, cnn_scores, AS, rc, keep):
382375def parseMulti (tempFile , winSize , genLand , modelName , cnn_scores , rc , keep , rmz , maxa , spliceD ):
383376 numLoc = [0 ,0 ] #Keep info on average number of places read maps to
384377 #Getting trained CNN
385- try :
386- json_file = open (modelName + '.json' , 'r' )
387- loaded_model_json = json_file .read ()
388- json_file .close ()
389- model = tf .keras .models .model_from_json (loaded_model_json )
390- model .load_weights (modelName + '.h5' )
391- if "mixed" in modelName :
392- modelName = 1
393- else :
394- modelName = 0
395- except :
396- print ("Model loading error" , flush = True )
397- print ("Could not load Tensorflow model :( Allo was written with Tensorflow version 2.11" )
398- sys .exit (0 )
399-
378+ if rc == 0 :
379+ try :
380+ json_file = open (modelName + '.json' , 'r' )
381+ loaded_model_json = json_file .read ()
382+ json_file .close ()
383+ model = tf .keras .models .model_from_json (loaded_model_json )
384+ model .load_weights (modelName + '.h5' )
385+ model = LiteModel .from_keras_model (model )
386+ except :
387+ print ("Could not load Tensorflow model :( Allo was written with Tensorflow version 2.11" )
388+ sys .exit (0 )
389+ else :
390+ model = None
391+ modelName = None
400392 #Exception that causes errors
401393 if os .stat (tempFile + "MM" ).st_size == 0 :
402394 return numLoc
@@ -737,8 +729,6 @@ def readAssignPE(rBlock, rBlock2, samOut, winSize, genLand, model, cnn_scores, r
737729 #Use no read score if zero region
738730 if s == 0 :
739731 scores_nn .append (0.0012 * (s + 1 ))
740- elif s <= 5 :
741- scores_nn .append (0.0062 * (s + 1 ))
742732 else :
743733 nn = predictPeak .predictNN (countArray , winSize , model )
744734 scores_nn .append (nn * (s + 1 ))
@@ -782,20 +772,24 @@ def readAssignPE(rBlock, rBlock2, samOut, winSize, genLand, model, cnn_scores, r
782772def parseMultiPE (tempFile , winSize , genLand , modelName , cnn_scores , rc , keep , rmz , maxa , spliceD ):
783773 numLoc = [0 ,0 ] #Retain info on number of mapping sites
784774 #Getting trained CNN and making sure there is a compatible tensorflow installed
785- try :
786- json_file = open (modelName + '.json' , 'r' )
787- loaded_model_json = json_file .read ()
788- json_file .close ()
789- model = tf .keras .models .model_from_json (loaded_model_json )
790- model .load_weights (modelName + '.h5' )
791- if "mixed" in modelName :
792- modelName = 1
793- else :
794- modelName = 0
795- except :
796- print ("PE model load error" )
797- print ("Could not load Tensorflow model :( Allo was written with Tensorflow version 2.11" )
798- sys .exit (0 )
775+ if rc == 0 :
776+ try :
777+ json_file = open (modelName + '.json' , 'r' )
778+ loaded_model_json = json_file .read ()
779+ json_file .close ()
780+ model = tf .keras .models .model_from_json (loaded_model_json )
781+ model .load_weights (modelName + '.h5' )
782+ model = LiteModel .from_keras_model (model )
783+ if "mixed" in modelName :
784+ modelName = 1
785+ else :
786+ modelName = 0
787+ except :
788+ print ("Could not load Tensorflow model :( Allo was written with Tensorflow version 2.11" )
789+ sys .exit (0 )
790+ else :
791+ model = None
792+ modelName = None
799793
800794 #Exception that causes errors
801795 if os .stat (tempFile + "MM" ).st_size == 0 :
@@ -878,3 +872,46 @@ def parseMultiPE(tempFile, winSize, genLand, modelName, cnn_scores, rc, keep, rm
878872 AL .close ()
879873
880874 return numLoc
875+
876+
#Class to speed up tensorflow prediction
#https://micwurm.medium.com/using-tensorflow-lite-to-speed-up-predictions-a3954886eb98
class LiteModel:
    """Wrap a TFLite interpreter behind a Keras-style predict() interface."""

    @classmethod
    def from_file(cls, model_path):
        #Load a pre-converted .tflite model straight from disk.
        return LiteModel(tf.lite.Interpreter(model_path=model_path))

    @classmethod
    def from_keras_model(cls, kmodel):
        #Convert an in-memory Keras model to TFLite bytes, then wrap it.
        lite_bytes = tf.lite.TFLiteConverter.from_keras_model(kmodel).convert()
        return LiteModel(tf.lite.Interpreter(model_content=lite_bytes))

    def __init__(self, interpreter):
        #Allocate tensors once up front and cache the I/O tensor metadata
        #so prediction calls avoid repeated detail lookups.
        self.interpreter = interpreter
        self.interpreter.allocate_tensors()
        in_detail = self.interpreter.get_input_details()[0]
        out_detail = self.interpreter.get_output_details()[0]
        self.input_index = in_detail["index"]
        self.output_index = out_detail["index"]
        self.input_shape = in_detail["shape"]
        self.output_shape = out_detail["shape"]
        self.input_dtype = in_detail["dtype"]
        self.output_dtype = out_detail["dtype"]

    def predict(self, inp):
        #The interpreter's input tensor has a fixed batch size of 1,
        #so feed the batch through one row at a time.
        inp = inp.astype(self.input_dtype)
        n_rows = inp.shape[0]
        out = np.zeros((n_rows, self.output_shape[1]), dtype=self.output_dtype)
        for row in range(n_rows):
            self.interpreter.set_tensor(self.input_index, inp[row:row + 1])
            self.interpreter.invoke()
            out[row] = self.interpreter.get_tensor(self.output_index)[0]
        return out

    def predict_single(self, inp):
        #Convenience path for one sample: add a batch dimension of 1,
        #invoke once, and strip the batch dimension off the result.
        batch = np.array([inp], dtype=self.input_dtype)
        self.interpreter.set_tensor(self.input_index, batch)
        self.interpreter.invoke()
        out = self.interpreter.get_tensor(self.output_index)
        return out[0]
0 commit comments