Skip to content

Commit 3c754bd

Browse files
Alexis Morrissey
authored and committed
Splice beta with tf lite
1 parent bc3fd48 commit 3c754bd

9 files changed

Lines changed: 139 additions & 79 deletions

File tree

Allo/.projectignore

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# This file contains a list of match patterns that instructs
2+
# anaconda-project to exclude certain files or directories when
3+
# building a project archive. The file format is a simplfied
4+
# version of Git's .gitignore file format. In fact, if the
5+
# project is hosted in a Git repository, these patterns can be
6+
# merged into the .gitignore file and this file removed.
7+
# See the anaconda-project documentation for more details.
8+
9+
# Python caching
10+
*.pyc
11+
*.pyd
12+
*.pyo
13+
__pycache__/
14+
15+
# Jupyter & Spyder stuff
16+
.ipynb_checkpoints/
17+
.Trash-*/
18+
/.spyderproject
260 Bytes
Binary file not shown.
14.7 KB
Binary file not shown.
1.25 KB
Binary file not shown.

Allo/allo

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ if __name__ == '__main__':
111111
rc = 2
112112
else:
113113
rc = 0
114+
print("Using neural network...", flush=True)
114115
#Keep unmapped reads
115116
if args.keep_unmap is not None:
116117
keep = 1
@@ -264,7 +265,7 @@ if __name__ == '__main__':
264265
shutil.rmtree(allo_dir)
265266
sys.exit(0)
266267

267-
print("Parsing finished!\n")
268+
print("Parsing finished!\n", flush=True)
268269

269270
#PHASE II: Parsing multi-mapped reads
270271
info2 = Parallel(n_jobs=thr)(delayed(allocation.parseMulti)(i, winSize, genLand, m, cnn_scores, rc, keep, rmz, maxa, spliceD) for i in tempList)
@@ -349,7 +350,7 @@ if __name__ == '__main__':
349350
shutil.rmtree(allo_dir)
350351
sys.exit(0)
351352

352-
print("Parsing finished!\n")
353+
print("Parsing finished!\n", flush=True)
353354

354355
#PHASE II: Parsing multi-mapped reads
355356
info2 = Parallel(n_jobs=thr)(delayed(allocation.parseMultiPE)(i, winSize, genLand, m, cnn_scores, rc, keep, rmz, maxa, spliceD) for i in tempList)

Allo/allocation.py

Lines changed: 110 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python
22
#Lexi Morrissey, Mahony Lab @ Pennsylvania State University
3-
#Last updated 03.01.2023
3+
#Last updated 04.22.2024
44
#Contains methods for read allocation procedure of Allo.
55

66
from Allo import predictPeak
@@ -17,6 +17,9 @@
1717
import sys
1818
import multiprocessing
1919
import re
20+
import absl.logging
21+
absl.logging.set_verbosity(absl.logging.ERROR)
22+
import contextlib, io
2023

2124
#Add reads to UMR dictionary
2225
def addToDict(tempFile, genLand, spliceD, seq):
@@ -31,8 +34,7 @@ def addToDict(tempFile, genLand, spliceD, seq):
3134
count = count + 1
3235
if seq == 0 and not (count % 2) == 0:
3336
continue
34-
#Getting closest 5bp window
35-
l[3] = int(l[3]))
37+
l[3] = int(l[3])
3638
#Add to dictionary
3739
key = l[2] + ";" + str(l[3])
3840
if key in genLand:
@@ -65,26 +67,24 @@ def addToDict(tempFile, genLand, spliceD, seq):
6567

6668
#Used to get counts in regions around multimapped reads
6769
def getArray(read, winSize, genLand, spliceD):
68-
bin = int(winSize/100)
6970
#Spliced array using junction info
7071
if spliceD:
7172
#Parsing cigar string
7273
chars = re.findall('[A-Za-z]+', read[5])
7374
num = re.findall('\d+\.\d+|\d+', read[5])
7475
gap_loc = {}
75-
start = int(round(int(read[3])/bin)*bin)
76+
start = int(read[3])
7677
chr = read[2]
7778
r_end = start #Get information about end of read to use for splice variants
7879
#print("cigar: " + read[5], flush=True)
7980
for i in range(0,len(chars)):
8081
if chars[i]=="M" or chars[i]=="D" or chars[i]=="=" or chars[i]=="X":
8182
r_end = r_end + int(num[i])
8283
elif chars[i]=="N":
83-
gap_loc[int(round((int(r_end)+4)/bin)*bin)]=int(round((int(num[i])-4)/bin)*bin)
84+
gap_loc[r_end+1]=int(num[i])
8485
r_end = r_end + int(num[i])
85-
r_end = int(round(int(r_end)/bin)*bin)
8686
array = []
87-
k = start-bin
87+
k = start
8888
l = 0
8989
#Upstream counts
9090
while l <= math.floor(winSize/2):
@@ -98,8 +98,8 @@ def getArray(read, winSize, genLand, spliceD):
9898
array.insert(0,genLand[key])
9999
else:
100100
array.insert(0,0)
101-
k -= bin
102-
l += bin
101+
k -= 1
102+
l += 1
103103
start_pos = k
104104
#Downstream counts
105105
k = start
@@ -116,8 +116,8 @@ def getArray(read, winSize, genLand, spliceD):
116116
array.append(genLand[key])
117117
else:
118118
array.append(0)
119-
k += bin
120-
l += bin
119+
k += 1
120+
l += 1
121121
#print("down2: " + str(k), flush=True)
122122
key = chr + ";" + str(k)
123123
if key in spliceD and spliceD[key] > 0:
@@ -127,14 +127,13 @@ def getArray(read, winSize, genLand, spliceD):
127127
array.append(genLand[key])
128128
else:
129129
array.append(0)
130-
k += bin
131-
l += bin
130+
k += 1
131+
l += 1
132132
#Non-spliced array
133133
else:
134134
array = []
135-
pos = int(read[3])
136-
k = pos-math.floor(winSize/2)
137-
while k < int(pos)+math.floor(winSize/2):
135+
pos = round(int(read[3])/100)*100
136+
for k in range (pos-math.floor(winSize/2),int(pos)+math.floor(winSize/2)):
138137
key = read[2] + ";" + str(k)
139138
#Seeing if current pos in genetic landscape
140139
if key in genLand:
@@ -155,34 +154,28 @@ def readAssign(rBlock, samOut, winSize, genLand, model, cnn_scores, rc, rmz, mod
155154
allZ = True #seeing if all zero regions
156155
for i in rBlock:
157156
#Find closest 100 window, use that score instead if it's already been assigned, saves time
158-
pos = i[2]+str(round(int(i[3])/100)*100)
159-
if pos in cnn_scores:
160-
#scores_nn.append(cnn_scores[pos])
157+
pos = i[2]+str(i[3])
158+
countArray = getArray(i, winSize, genLand, spliceD)
159+
s = sum(countArray)
160+
if s > 0:
161161
allZ = False
162-
else:
163-
countArray = getArray(i, winSize, genLand, spliceD)
164-
s = sum(countArray)
165-
if s > 0:
166-
allZ = False
167-
#Allocation options
168-
if rc == 1:
169-
if s == 0:
170-
scores_rc.append(1)
171-
else:
172-
scores_rc.append(s+1)
173-
continue
174-
if rc == 2:
175-
scores_rc.append(1)
176-
continue
177-
#Use no read score if zero region
162+
#Allocation options
163+
if rc == 1:
178164
if s == 0:
179-
scores_nn.append(0.0012*(s+1))
180-
elif s <= 5:
181-
scores_nn.append(0.0062*(s+1))
165+
scores_rc.append(1)
182166
else:
183-
nn = predictPeak.predictNN(countArray, winSize, model)
184-
scores_nn.append(nn*(s+1))
185-
cnn_scores[pos] = (nn*(s+1))
167+
scores_rc.append(s+1)
168+
continue
169+
if rc == 2:
170+
scores_rc.append(1)
171+
continue
172+
#Use no read score if zero region
173+
if s == 0:
174+
scores_nn.append(0.0012*(s+1))
175+
else:
176+
nn = predictPeak.predictNN(countArray, winSize, model)
177+
scores_nn.append(nn*(s+1))
178+
cnn_scores[pos] = (nn*(s+1))
186179

187180
#Removing reads that mapped to all zero regions
188181
if allZ and rmz == 1:
@@ -382,21 +375,20 @@ def parseUniq(tempFile, winSize, cnn_scores, AS, rc, keep):
382375
def parseMulti(tempFile, winSize, genLand, modelName, cnn_scores, rc, keep, rmz, maxa, spliceD):
383376
numLoc = [0,0] #Keep info on average number of places read maps to
384377
#Getting trained CNN
385-
try:
386-
json_file = open(modelName+'.json', 'r')
387-
loaded_model_json = json_file.read()
388-
json_file.close()
389-
model = tf.keras.models.model_from_json(loaded_model_json)
390-
model.load_weights(modelName+'.h5')
391-
if "mixed" in modelName:
392-
modelName = 1
393-
else:
394-
modelName = 0
395-
except:
396-
print("Model loading error", flush=True)
397-
print("Could not load Tensorflow model :( Allo was written with Tensorflow version 2.11")
398-
sys.exit(0)
399-
378+
if rc == 0:
379+
try:
380+
json_file = open(modelName+'.json', 'r')
381+
loaded_model_json = json_file.read()
382+
json_file.close()
383+
model = tf.keras.models.model_from_json(loaded_model_json)
384+
model.load_weights(modelName+'.h5')
385+
model = LiteModel.from_keras_model(model)
386+
except:
387+
print("Could not load Tensorflow model :( Allo was written with Tensorflow version 2.11")
388+
sys.exit(0)
389+
else:
390+
model = None
391+
modelName = None
400392
#Exception that causes errors
401393
if os.stat(tempFile+"MM").st_size == 0:
402394
return numLoc
@@ -737,8 +729,6 @@ def readAssignPE(rBlock, rBlock2, samOut, winSize, genLand, model, cnn_scores, r
737729
#Use no read score if zero region
738730
if s == 0:
739731
scores_nn.append(0.0012*(s+1))
740-
elif s <= 5:
741-
scores_nn.append(0.0062*(s+1))
742732
else:
743733
nn = predictPeak.predictNN(countArray, winSize, model)
744734
scores_nn.append(nn*(s+1))
@@ -782,20 +772,24 @@ def readAssignPE(rBlock, rBlock2, samOut, winSize, genLand, model, cnn_scores, r
782772
def parseMultiPE(tempFile, winSize, genLand, modelName, cnn_scores, rc, keep, rmz, maxa, spliceD):
783773
numLoc = [0,0] #Retain info on number of mapping sites
784774
#Getting trained CNN and making sure there is a compatible tensorflow installed
785-
try:
786-
json_file = open(modelName+'.json', 'r')
787-
loaded_model_json = json_file.read()
788-
json_file.close()
789-
model = tf.keras.models.model_from_json(loaded_model_json)
790-
model.load_weights(modelName+'.h5')
791-
if "mixed" in modelName:
792-
modelName = 1
793-
else:
794-
modelName = 0
795-
except:
796-
print("PE model load error")
797-
print("Could not load Tensorflow model :( Allo was written with Tensorflow version 2.11")
798-
sys.exit(0)
775+
if rc == 0:
776+
try:
777+
json_file = open(modelName+'.json', 'r')
778+
loaded_model_json = json_file.read()
779+
json_file.close()
780+
model = tf.keras.models.model_from_json(loaded_model_json)
781+
model.load_weights(modelName+'.h5')
782+
model = LiteModel.from_keras_model(model)
783+
if "mixed" in modelName:
784+
modelName = 1
785+
else:
786+
modelName = 0
787+
except:
788+
print("Could not load Tensorflow model :( Allo was written with Tensorflow version 2.11")
789+
sys.exit(0)
790+
else:
791+
model = None
792+
modelName = None
799793

800794
#Exception that causes errors
801795
if os.stat(tempFile+"MM").st_size == 0:
@@ -878,3 +872,46 @@ def parseMultiPE(tempFile, winSize, genLand, modelName, cnn_scores, rc, keep, rm
878872
AL.close()
879873

880874
return numLoc
875+
876+
877+
#Class to speed up tensorflow prediction
878+
#https://micwurm.medium.com/using-tensorflow-lite-to-speed-up-predictions-a3954886eb98
879+
class LiteModel:
880+
@classmethod
881+
def from_file(cls, model_path):
882+
return LiteModel(tf.lite.Interpreter(model_path=model_path))
883+
884+
@classmethod
885+
def from_keras_model(cls, kmodel):
886+
converter = tf.lite.TFLiteConverter.from_keras_model(kmodel)
887+
tflite_model = converter.convert()
888+
return LiteModel(tf.lite.Interpreter(model_content=tflite_model))
889+
890+
def __init__(self, interpreter):
891+
self.interpreter = interpreter
892+
self.interpreter.allocate_tensors()
893+
input_det = self.interpreter.get_input_details()[0]
894+
output_det = self.interpreter.get_output_details()[0]
895+
self.input_index = input_det["index"]
896+
self.output_index = output_det["index"]
897+
self.input_shape = input_det["shape"]
898+
self.output_shape = output_det["shape"]
899+
self.input_dtype = input_det["dtype"]
900+
self.output_dtype = output_det["dtype"]
901+
902+
def predict(self, inp):
903+
inp = inp.astype(self.input_dtype)
904+
count = inp.shape[0]
905+
out = np.zeros((count, self.output_shape[1]), dtype=self.output_dtype)
906+
for i in range(count):
907+
self.interpreter.set_tensor(self.input_index, inp[i:i+1])
908+
self.interpreter.invoke()
909+
out[i] = self.interpreter.get_tensor(self.output_index)[0]
910+
return out
911+
912+
def predict_single(self, inp):
913+
inp = np.array([inp], dtype=self.input_dtype)
914+
self.interpreter.set_tensor(self.input_index, inp)
915+
self.interpreter.invoke()
916+
out = self.interpreter.get_tensor(self.output_index)
917+
return out[0]

Allo/predictPeak.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python
22
#Lexi Morrissey, Mahony Lab @ Pennsylvania State University
3-
#Last updated 03.01.2023
3+
#Last updated 04.22.2024
44
#Contains method for predicting whether area should receive multimapped reads via pre-trained CNN in Allo.
55

66
import os
@@ -12,6 +12,7 @@
1212
import os
1313
tf.config.run_functions_eagerly(False)
1414
import math
15+
import sys
1516

1617

1718
def predictNN(counts, winSize, model):
@@ -35,9 +36,11 @@ def predictNN(counts, winSize, model):
3536
pic[binned[i],i] = 1
3637

3738
try:
38-
pred = model(pic.reshape(-1,100,100), training=False)
39-
return pred.numpy()[0][0]
39+
#pred = model(pic.reshape(-1,100,100), training=False)
40+
pred = model.predict(pic.reshape(-1,100,100))
41+
#return pred.numpy()[0][0]
42+
return pred[0][0]
4043

4144
except:
42-
print("Could not predict with Tensorflow model :( Allo was written with Tensorflow version 2.11")
45+
print("Could not predict with Tensorflow model :( Allo was written with Tensorflow version 2.11", flush=True)
4346
sys.exit(0)

Allo/rna.h5

279 KB
Binary file not shown.

Allo/rna.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"class_name": "Sequential", "config": {"name": "sequential", "layers": [{"module": "keras.layers", "class_name": "InputLayer", "config": {"batch_input_shape": [null, 100, 100], "dtype": "float32", "sparse": false, "ragged": false, "name": "conv1d_input"}, "registered_name": null}, {"module": "keras.layers", "class_name": "Conv1D", "config": {"name": "conv1d", "trainable": true, "dtype": "float32", "batch_input_shape": [null, 100, 100], "filters": 2, "kernel_size": [64], "strides": [1], "padding": "same", "data_format": "channels_last", "dilation_rate": [1], "groups": 1, "activation": "linear", "use_bias": true, "kernel_initializer": {"module": "keras.initializers", "class_name": "GlorotUniform", "config": {"seed": null}, "registered_name": null}, "bias_initializer": {"module": "keras.initializers", "class_name": "Zeros", "config": {}, "registered_name": null}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "registered_name": null, "build_config": {"input_shape": [null, 100, 100]}}, {"module": "keras.layers", "class_name": "AveragePooling1D", "config": {"name": "average_pooling1d", "trainable": true, "dtype": "float32", "strides": [2], "pool_size": [2], "padding": "valid", "data_format": "channels_last"}, "registered_name": null, "build_config": {"input_shape": [null, 100, 2]}}, {"module": "keras.layers", "class_name": "Conv1D", "config": {"name": "conv1d_1", "trainable": true, "dtype": "float32", "filters": 2, "kernel_size": [32], "strides": [1], "padding": "same", "data_format": "channels_last", "dilation_rate": [1], "groups": 1, "activation": "linear", "use_bias": true, "kernel_initializer": {"module": "keras.initializers", "class_name": "GlorotUniform", "config": {"seed": null}, "registered_name": null}, "bias_initializer": {"module": "keras.initializers", "class_name": "Zeros", "config": {}, "registered_name": null}, "kernel_regularizer": null, "bias_regularizer": null, 
"activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "registered_name": null, "build_config": {"input_shape": [null, 50, 2]}}, {"module": "keras.layers", "class_name": "Conv1D", "config": {"name": "conv1d_2", "trainable": true, "dtype": "float32", "filters": 2, "kernel_size": [32], "strides": [1], "padding": "same", "data_format": "channels_last", "dilation_rate": [1], "groups": 1, "activation": "linear", "use_bias": true, "kernel_initializer": {"module": "keras.initializers", "class_name": "GlorotUniform", "config": {"seed": null}, "registered_name": null}, "bias_initializer": {"module": "keras.initializers", "class_name": "Zeros", "config": {}, "registered_name": null}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "registered_name": null, "build_config": {"input_shape": [null, 50, 2]}}, {"module": "keras.layers", "class_name": "Dropout", "config": {"name": "dropout", "trainable": true, "dtype": "float32", "rate": 0.5, "noise_shape": null, "seed": null}, "registered_name": null, "build_config": {"input_shape": [null, 50, 2]}}, {"module": "keras.layers", "class_name": "Flatten", "config": {"name": "flatten", "trainable": true, "dtype": "float32", "data_format": "channels_last"}, "registered_name": null, "build_config": {"input_shape": [null, 50, 2]}}, {"module": "keras.layers", "class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": "float32", "units": 512, "activation": "relu", "use_bias": true, "kernel_initializer": {"module": "keras.initializers", "class_name": "GlorotUniform", "config": {"seed": null}, "registered_name": null}, "bias_initializer": {"module": "keras.initializers", "class_name": "Zeros", "config": {}, "registered_name": null}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "registered_name": null, "build_config": 
{"input_shape": [null, 100]}}, {"module": "keras.layers", "class_name": "Dense", "config": {"name": "dense_1", "trainable": true, "dtype": "float32", "units": 1, "activation": "sigmoid", "use_bias": true, "kernel_initializer": {"module": "keras.initializers", "class_name": "GlorotUniform", "config": {"seed": null}, "registered_name": null}, "bias_initializer": {"module": "keras.initializers", "class_name": "Zeros", "config": {}, "registered_name": null}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "registered_name": null, "build_config": {"input_shape": [null, 512]}}]}, "keras_version": "2.15.0", "backend": "tensorflow"}

0 commit comments

Comments
 (0)