Skip to content

Commit 34b76b8

Browse files
author
Alexis Morrissey
committed
Adding new cnn
1 parent 59bb4e8 commit 34b76b8

File tree

3 files changed

+9
-2
lines changed

3 files changed

+9
-2
lines changed

Allo/allo

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ parser.add_argument('-seq', type=str, nargs=1, help='Single-end or paired-end se
1414
choices=['pe','se'], required=True, dest='seq')
1515
parser.add_argument('-o', type=str, nargs=1, help='Output file name', dest='outfile', default=None)
1616
parser.add_argument('--mixed', help='Use CNN trained on a dataset with mixed peaks, narrow by default', action='store_true', default=None)
17+
parser.add_argument('--rna', help='Use CNN trained on a dataset with mixed peaks, narrow by default', action='store_true', default=None)
1718
parser.add_argument('-p', type=int, nargs=1, help='Number of processes, 1 by default', dest='processes', default=None)
1819
parser.add_argument('-max', type=int, nargs=1, help='Maximum value for number of locations a read can map', dest='maxlocations', default=None)
1920
parser.add_argument('--keep-unmap', help='Keep unmapped reads and reads that include N in their sequence', action='store_true', default=None)
@@ -84,6 +85,10 @@ if __name__ == '__main__':
8485
d = os.path.dirname(sys.modules["Allo"].__file__)
8586
m = os.path.join(d, "mixed")
8687
winSize = 500
88+
elif args.rna is not None:
89+
d = os.path.dirname(sys.modules["Allo"].__file__)
90+
m = os.path.join(d, "rna")
91+
winSize = 1000
8792
else:
8893
d = os.path.dirname(sys.modules["Allo"].__file__)
8994
m = os.path.join(d, "narrow")

Allo/allocation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def parseUniq(tempFile, winSize, cnn_scores, AS, rc, keep):
292292
return [cu, cf]
293293

294294

295-
def parseMulti(tempFile, winSize, genLand, modelName, cnn_scores, rc, keep, rmz, maxa):
295+
def parseMulti(tempFile, winSize, genLand, modelName, cnn_scores, rc, keep, rmz, maxa):
296296
numLoc = [0,0] #Keep info on average number of places read maps to
297297
#Getting trained CNN
298298
try:
@@ -306,6 +306,7 @@ def parseMulti(tempFile, winSize, genLand, modelName, cnn_scores, rc, keep, rmz,
306306
else:
307307
modelName = 0
308308
except:
309+
print("Model loading error", flush=True)
309310
print("Could not load Tensorflow model :( Allo was written with Tensorflow version 2.11")
310311
sys.exit(0)
311312

@@ -705,6 +706,7 @@ def parseMultiPE(tempFile, winSize, genLand, modelName, cnn_scores, rc, keep, rm
705706
else:
706707
modelName = 0
707708
except:
709+
print("PE model load error")
708710
print("Could not load Tensorflow model :( Allo was written with Tensorflow version 2.11")
709711
sys.exit(0)
710712

Allo/predictPeak.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616

1717
def predictNN(counts, winSize, model):
18-
winSize = math.floor(winSize/5) #Genome is already binned 5bps
1918
pic = np.zeros((100, 100), float)
2019
if sum(counts) == 0:
2120
pred = model(pic.reshape(-1,100,100), training=False)
@@ -40,6 +39,7 @@ def predictNN(counts, winSize, model):
4039
return pred.numpy()[0][0]
4140

4241
except:
42+
print("Prediction error.", flush=True)
4343
print("Could not predict with Tensorflow model :( Allo was written with Tensorflow version 2.11")
4444
sys.exit(0)
4545

0 commit comments

Comments
 (0)