source: gs3-extensions/structured-audio/trunk/bin/script/diarization.py@ 36211

Last change on this file since 36211 was 36211, checked in by davidb, 2 years ago

Progress towards having a StructuredAudioPlugin that uses pyannote.audio to process audio files

  • Property svn:executable set to *
File size: 2.9 KB
Line 
#!/usr/bin/env python
import argparse
import csv
import os
import sys
import time
from pathlib import Path

from pyannote.audio import Pipeline

# remove available graphic cards to trigger cpu default
os.environ["CUDA_VISIBLE_DEVICES"] = ""


def parse_args(argv=None):
    """Parse the command line.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Namespace with ``inputfile``, ``outputfile`` (``None`` when omitted)
        and ``mingap`` (seconds, defaults to 1).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('inputfile', help="Audio input file")
    parser.add_argument('outputfile', nargs="?", help="Output file (.csv, optional)")
    # float rather than int: segment times are kept to 0.1 s, so sub-second
    # gaps are meaningful. Whole-number values are still accepted.
    parser.add_argument("--mingap", type=float, default=1,
                        help="Minimum gap size in seconds between same-speaker segments")
    return parser.parse_args(argv)


def default_output_path(input_file):
    """Return the input path with its file extension replaced by ``.csv``."""
    return str(Path(input_file).with_suffix(".csv"))


def diarize(input_file):
    """Run the pretrained pyannote speaker-diarization pipeline on a file.

    Returns:
        List of ``(speaker, start, end)`` tuples in the order yielded by
        ``itertracks``, with times rounded to one decimal place.
    """
    pipeline = Pipeline.from_pretrained('pyannote/speaker-diarization')
    diarization = pipeline({'audio': input_file})
    return [
        (speaker, round(turn.start, 1), round(turn.end, 1))
        for turn, _, speaker in diarization.itertracks(yield_label=True)
    ]


def merge_segments(segments, gap_threshold):
    """Join consecutive segments of the same speaker separated by a small gap.

    Args:
        segments: Iterable of ``(speaker, start, end)`` tuples.
        gap_threshold: Two consecutive segments of the same speaker are joined
            when the gap between them is strictly smaller than this value.

    Returns:
        New list of ``(speaker, start, end)`` tuples; empty for empty input.
    """
    merged = []
    for speaker, start, end in segments:
        if merged:
            prev_speaker, prev_start, prev_end = merged[-1]
            if speaker == prev_speaker and start - prev_end < gap_threshold:
                # max() so a segment nested inside the previous one
                # cannot shorten the merged segment.
                merged[-1] = (prev_speaker, prev_start, max(prev_end, end))
                continue
        merged.append((speaker, start, end))
    return merged


def main(argv=None):
    """Diarize an audio file and write the gap-merged segments as CSV.

    Each output row is ``speaker,start,end``.

    Returns:
        Process exit status: 0 on success, 1 when the input file is missing
        or the output file cannot be written.
    """
    args = parse_args(argv)
    input_file = args.inputfile
    gap_threshold = args.mingap
    output_file = args.outputfile
    if output_file is None:
        output_file = default_output_path(input_file)

    if not os.path.exists(input_file):
        print("error: file " + input_file + " cannot be found")
        return 1

    print("starting pyannote pipeline with file: " + input_file)
    time_start = time.perf_counter()  # timer for performance monitoring
    segments = diarize(input_file)
    print("pipeline completed.")
    print(f"processTime: {time.perf_counter() - time_start:.1f}s")

    print("starting gap-removal with file: " + output_file)
    print(f"minimum gap: {gap_threshold:g}s")
    merged = merge_segments(segments, gap_threshold)
    try:
        # newline="" lets the csv module control row endings itself.
        with open(output_file, mode="w", newline="") as out_file:
            csv.writer(out_file, delimiter=',').writerows(merged)
    except OSError as e:
        print(e)
        return 1

    print("gap-removal completed.")
    return 0


if __name__ == "__main__":
    sys.exit(main())
Note: See TracBrowser for help on using the repository browser.