#!/usr/bin/env python
import argparse
import csv
import os
import sys
import time
from pathlib import Path

from pyannote.audio import Pipeline

# Hide every CUDA device so the pipeline falls back to running on the CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = ""


def parse_args(argv=None):
    """Parse the command-line arguments.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``inputfile``, ``outputfile``
        (``None`` when omitted) and ``mingap`` (int, default 1).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('inputfile', help="Audio input file")
    parser.add_argument('outputfile', nargs="?", help="Output file (.csv, optional)")
    parser.add_argument(
        "--mingap",
        type=int,
        help="Minimum gap size in seconds between same-speaker segments",
        default=1,
    )
    return parser.parse_args(argv)


def default_output_path(input_path):
    """Return *input_path* with its file extension replaced by ``.csv``."""
    return str(Path(input_path).with_suffix(".csv"))


def merge_segments(names, starts, ends, gap_threshold):
    """Join consecutive segments of the same speaker separated by a small gap.

    Two neighbouring segments are joined when they carry the same speaker
    name and the pause between the end of the first and the start of the
    second is strictly smaller than *gap_threshold*. A joined segment keeps
    the start of its first member and the end of its last member.

    Args:
        names: Speaker label of each segment, in file order.
        starts: Start time (seconds) of each segment.
        ends: End time (seconds) of each segment.
        gap_threshold: Minimum pause (seconds) that keeps segments apart.

    Returns:
        A list of ``(name, start, end)`` tuples; empty for empty input.
    """
    merged = []
    for name, start, end in zip(names, starts, ends):
        if merged:
            prev_name, prev_start, prev_end = merged[-1]
            if name == prev_name and start - prev_end < gap_threshold:
                # Extend the previous segment instead of emitting a new one.
                merged[-1] = (prev_name, prev_start, end)
                continue
        merged.append((name, start, end))
    return merged


def diarize(audio_path):
    """Run the pretrained pyannote speaker-diarization pipeline on a file.

    Args:
        audio_path: Path of the audio file to analyse.

    Returns:
        Three parallel lists ``(names, starts, ends)`` with the speaker
        label and the start/end time of every track, times rounded to one
        decimal place.
    """
    pipeline = Pipeline.from_pretrained('pyannote/speaker-diarization')
    diarization = pipeline({'audio': audio_path})
    names, starts, ends = [], [], []
    for turn, _, speaker in diarization.itertracks(yield_label=True):
        names.append(str(speaker))
        # float() normalises whatever numeric type round() hands back.
        starts.append(float(round(turn.start, 1)))
        ends.append(float(round(turn.end, 1)))
    return names, starts, ends


def write_segments(output_path, rows):
    """Write ``(name, start, end)`` rows to *output_path* as comma-separated CSV.

    Raises:
        OSError: If the file cannot be opened or written.
    """
    # newline="" lets the csv module control line endings (no blank rows
    # on platforms that translate "\n").
    with open(output_path, mode="w", newline="") as out_file:
        csv.writer(out_file, delimiter=',').writerows(rows)


def main(argv=None):
    """Diarize the input audio file and write gap-merged speaker segments.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 on success, 1 when the input file is
        missing or the output file cannot be written.
    """
    args = parse_args(argv)
    input_path = args.inputfile
    gap_threshold = args.mingap

    if not os.path.exists(input_path):
        print("error: file " + input_path + " cannot be found", file=sys.stderr)
        return 1

    output_path = args.outputfile
    if output_path is None:
        output_path = default_output_path(input_path)

    print("starting pyannote pipeline with file: " + input_path)
    time_start = time.perf_counter()  # timer for performance monitoring
    names, starts, ends = diarize(input_path)
    print("pipeline completed.")
    print(f"processTime: {time.perf_counter()-time_start:.1f}s")

    print("starting gap-removal with file: " + output_path)
    print("minimum gap: " + str(gap_threshold) + "s")
    rows = merge_segments(names, starts, ends, gap_threshold)
    try:
        write_segments(output_path, rows)
    except OSError as e:
        print(e, file=sys.stderr)
        return 1
    print("gap-removal completed.")
    return 0


if __name__ == "__main__":
    sys.exit(main())