import os import tensorflow as tf from tensorflow ...
創建於:2025年3月10日
創建於:2025年3月10日
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
def create_datasets(data_dir, img_size=(224, 224), batch_size=16):
"""
從資料夾產生訓練/驗證資料集。
資料夾結構需嚴格為四類:
data_dir/
├─ empty/
├─ normal/
├─ overlap/
└─ tilted/
"""
# 明確指定類別名稱與順序
class_names = ['empty', 'normal', 'overlap', 'tilted'] # 需與資料夾名稱一致
texttrain_ds = tf.keras.preprocessing.image_dataset_from_directory( data_dir, validation_split=0.2, subset="training", seed=123, image_size=img_size, batch_size=batch_size, class_names=class_names # 強制指定類別順序 ) val_ds = tf.keras.preprocessing.image_dataset_from_directory( data_dir, validation_split=0.2, subset="validation", seed=123, image_size=img_size, batch_size=batch_size, class_names=class_names ) AUTOTUNE = tf.data.AUTOTUNE train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE) val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE) return train_ds, val_ds, class_names # 返回類別名稱
def build_model(input_shape=(224,224,3), num_classes=4):
"""
使用遷移學習,輸出層為多類分類(softmax + sparse_categorical_crossentropy)
"""
base_model = keras.applications.MobileNetV2(
input_shape=input_shape,
include_top=False,
weights='imagenet'
)
base_model.trainable = False
textinputs = keras.Input(shape=input_shape) x = keras.applications.mobilenet_v2.preprocess_input(inputs) # 正規化輸入 x = base_model(x, training=False) x = layers.GlobalAveragePooling2D()(x) x = layers.Dropout(0.2)(x) outputs = layers.Dense(num_classes, activation="softmax")(x) # 多類分類用 softmax model = keras.Model(inputs, outputs) model.compile( optimizer=keras.optimizers.Adam(learning_rate=1e-3), loss='sparse_categorical_crossentropy', # 多類分類交叉熵 metrics=['accuracy'] ) return model
def main():
data_dir = "dataset" # 資料夾需包含 empty, normal, overlap, tilted 四個子目錄
text# 建立資料集時明確指定類別 train_ds, val_ds, class_names = create_datasets(data_dir, img_size=(224,224), batch_size=16) print("類別名稱:", class_names) # 確認順序 wafer_model = build_model(input_shape=(224,224,3), num_classes=4) wafer_model.summary() # 添加訓練優化(早停法與模型檢查點) callbacks = [ keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True), keras.callbacks.ModelCheckpoint("best_model.h5", save_best_only=True) ] epochs = 100 # 早停法可減少訓練次數 history = wafer_model.fit( train_ds, epochs=epochs, validation_data=val_ds, callbacks=callbacks ) # 保存模型與類別名稱 wafer_model.save("wafer_model.h5") with open("class_names.json", "w") as f: import json json.dump(class_names, f) print("模型與類別名稱已儲存")
if name == "main":
main()
上面程式是一個缺陷訓練的程式, 請使用 python 程式寫出 : 用 filedialog 讀取 avi檔案, 利用上面訓練好的模型 (wafer_model.h5 ),使用yolo 物件追蹤, 可以即時框出這四種缺陷物件,並且把框出的結果儲存到目錄 bed 裡面
下面提供一個範例程式,示範以下流程:
tkinter.filedialog
讓使用者選擇 .avi
檔案。wafer_model.h5
對框選到的區域做判斷 (四類:empty
, normal
, overlap
, tilted
)。bed
資料夾。注意
- 此範例示意如何整合 YOLO 與你的分類模型,實作時你需要自行準備好 YOLO 的權重與程式碼。
- 範例中以下以 Ultralytics YOLO 為示範(
pip install ultralytics
),它提供了yolo_model = YOLO("yolov8n.pt")
這種使用方式,你可以換成任何你有的 YOLO 模型程式。- 若你使用的是 OpenCV DNN 方式(
net = cv2.dnn.readNet(...)
),請自行替換 YOLO 推論段落的程式即可。- 若要更完整的「跨影格」物件追蹤(Tracking),可再整合如 OpenCV Tracking API 或 ByteTrack 等追蹤方法;這裡僅示範「逐幀偵測 + 分類」。
- 下面程式可以直接寫成單一檔案執行,也可以拆分成多個檔案,視個人需求而定。
pythonimport os import cv2 import numpy as np import tensorflow as tf from tensorflow import keras from tkinter import Tk, filedialog # 若要使用 Ultralytics YOLO,需要先安裝: # pip install ultralytics from ultralytics import YOLO def main(): # 1. 透過 filedialog 選取檔案 root = Tk() root.withdraw() # 隱藏主視窗 video_path = filedialog.askopenfilename( title="選擇要偵測的 AVI 影片", filetypes=[("AVI files", "*.avi"), ("All files", "*.*")] ) if not video_path: print("未選擇檔案,程式結束。") return # 2. 讀取已訓練好的分類模型 (wafer_model.h5) wafer_model = keras.models.load_model("wafer_model.h5") # 這是你在訓練時的四類名稱 class_names = ['empty', 'normal', 'overlap', 'tilted'] # 3. 載入YOLO (此處以Ultralytics YOLO為例,請確保已經安裝) # 假設這裡是使用官方yolov8n預訓練權重,你可換成自己的權重檔 yolo_model = YOLO("yolov8n.pt") # 建立 "bed" 資料夾,存放結果 output_dir = "bed" os.makedirs(output_dir, exist_ok=True) # 4. 用 OpenCV 讀取 .avi cap = cv2.VideoCapture(video_path) if not cap.isOpened(): print("無法開啟影片: ", video_path) return # 影片輸出參數設定 fourcc = cv2.VideoWriter_fourcc(*'XVID') # 可根據需求改成其他編碼 fps = cap.get(cv2.CAP_PROP_FPS) width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # 輸出檔案路徑(存在 bed 資料夾中) output_video_path = os.path.join(output_dir, "detected_output.avi") out_writer = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height)) print("開始處理影片,請稍後...") while True: ret, frame = cap.read() if not ret: # 影片讀取完畢 break # 5. 使用 YOLO 偵測 (逐幀偵測) # results 會包含該帧所有偵測到的物件 results = yolo_model(frame, conf=0.25) # conf可自行調整閾值 # Ultralytics YOLO 可能一次傳回多個 result(batch),但預設這裡通常是 1 張圖 for r in results: # r.boxes 是該帧所有偵測框 boxes = r.boxes for box in boxes: # 取得框選座標 (xyxy 格式) # box.xyxy 為 tensor,需要轉為 int x1, y1, x2, y2 = map(int, box.xyxy[0]) # 6. ROI 區域裁切 (傳入你的 wafer_model 分類) roi = frame[y1:y2, x1:x2] # 如果框很小或無法正確裁切,要做檢查 if roi.size == 0: continue # 尺寸調整到 (224,224) 與你的模型一致 resized_roi = cv2.resize(roi, (224, 224)) # 正規化 (若你的模型需要, 例如 mobilenet_v2 的 preprocess_input) # 這裡簡單做範例:可依實際需要套用 # 由於在 build_model 時已經有 keras.applications.mobilenet_v2.preprocess_input # 這裡就可以直接把 [0,255] 圖片傳入即可,或自行手動呼叫 preprocess_input input_data = np.expand_dims(resized_roi, axis=0) # (1,224,224,3) # 分類推論 preds = wafer_model.predict(input_data) cls_index = np.argmax(preds, axis=1)[0] cls_label = class_names[cls_index] # 7. 在畫面上畫出框與類別名稱 cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2) cv2.putText(frame, cls_label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2) # 將畫完框的影像寫入 output.avi out_writer.write(frame) # 若想邊寫檔邊顯示,可使用以下程式段 cv2.imshow("YOLO Detection + Classification", frame) if cv2.waitKey(1) & 0xFF == 27: # 按 ESC 離開 break cap.release() out_writer.release() cv2.destroyAllWindows() print("處理完成,結果已儲存在:", output_video_path) if __name__ == "__main__": main()
.avi
檔案。wafer_model.h5
做四類缺陷 (empty
, normal
, overlap
, tilted
) 的分類。bed
資料夾的 detected_output.avi
。ESC
,可以提前結束處理。你可以依據實際需求來修改:
- YOLO 模型的權重、閾值等 (
yolo_model(frame, conf=0.25)
之處)。- 視訊輸出編碼器 (
fourcc
) 與檔案格式。- 是否在影格上做其他標註或儲存資訊等。
如此即可達到「讀取 AVI → YOLO 偵測 → 使用你訓練好的 wafer_model 進行分類 → 即時框出缺陷物 → 輸出結果」的流程。
import cv2
from sort import *
import math
import numpy as np
from ultralytics import YOLO
import cvzone
from time import sleep
cap = cv2.VideoCapture('Video0.avi')
model = YOLO(r'D:\venv\OTMS1\best.pt')
classnames = []
with open('classes.txt', 'r') as f:
classnames = f.read().splitlines()
tracker = Sort(max_age=20)
line = [120, 1050, 920, 1050]
wafer_count = 0
counted_ids = [] # 用來記錄已經計數過的物件 ID
while cap.isOpened():
sleep(0.1) # 可依需求調整每幀間的延遲時間
ret, frame = cap.read()
if not ret:
cap = cv2.VideoCapture('Video0.avi')
#continue
text# 使用 YOLO 模型進行偵測 (stream 模式逐張處理) result = model(frame, stream=True) detections = [] # 存放每一幀的偵測結果 [x1, y1, x2, y2, conf] for info in result: boxes = info.boxes for box in boxes: # 取得邊界框座標與信心分數 x1, y1, x2, y2 = box.xyxy[0] conf = box.conf[0] classindex = int(box.cls[0]) x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) conf = float(conf) # 將每個偵測結果加入列表 (注意:這邊皆以 wafer 為例,若要僅追蹤特定類別可加入條件篩選) detections.append([x1, y1, x2, y2, conf]) # 繪製原始偵測框與標籤 cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2) cvzone.putTextRect(frame, f'{classnames[classindex]} {math.ceil(conf * 100)}%', [x1 + 8, y1 - 12], thickness=2, scale=3) # 計算並標示偵測物件中心點(紅點) cx = int((x1 + x2) / 2) cy = int((y1 + y2) / 2) cv2.circle(frame, (cx, cy), 5, (0, 0, 255), -1) # 若有偵測結果則轉換為 numpy 陣列,否則傳入空陣列 if len(detections) > 0: dets = np.array(detections) else: dets = np.empty((0, 5)) # 更新追蹤器,取得追蹤物件 (每列 [x1, y1, x2, y2, track_id]) tracks = tracker.update(dets) # 在影像上劃出計數用的直線 cv2.line(frame, (line[0], line[1]), (line[2], line[3]), (255, 0, 0), thickness=10) # 遍歷每一個追蹤物件 for track in tracks: x1, y1, x2, y2, track_id = track x1, y1, x2, y2, track_id = int(x1), int(y1), int(x2), int(y2), int(track_id) # 繪製追蹤框與 ID cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 0), 2) cvzone.putTextRect(frame, str(track_id), [x1, y1], scale=2, thickness=2) # 計算物件中心點 (此處雖然已在偵測時標示,但依需求可再次標示) cx = int((x1 + x2) / 2) cy = int((y1 + y2) / 2) cv2.circle(frame, (cx, cy), 5, (0, 0, 255), -1) # 判斷若中心點位於計數線附近,且尚未計數,則累計 wafer 數量 if (line[0] <= cx <= line[2]) and abs(cy - line[1]) < 10: if track_id not in counted_ids: wafer_count += 1 counted_ids.append(track_id) # 在影像上顯示計數數量 cvzone.putTextRect(frame, f'Count: {wafer_count}', [50, 50], scale=3, thickness=2) cv2.namedWindow('Wafer Detection', cv2.WINDOW_NORMAL) cv2.resizeWindow('Wafer Detection', 500, 700) cv2.imshow('Wafer Detection', frame) cv2.waitKey(1)
cap.release()
cv2.destroyAllWindows()
上面程式沒有正確數出正確的wafer 的數量出來,請修改一下
下面的範例修改了原始程式用「跨線計數」的典型做法,來避免重複或漏算的問題。其核心邏輯是:
若你只是想用「任何物件只要觸碰線一次就算 1」的簡易做法,也可以保留
counted_ids
這樣的集合,當某個 track ID 第一次跨線就加入集合,之後就不再計數。
這裡提供兩種範例寫法,你可擇一使用。
y = line[1]
(同時 line[0]
, line[2]
代表線在畫面 X 方向的起訖)。(cx, cy)
與之前上一幀的中心點 (pcx, pcy)
比較:
pcy < line_y
(在線上方)且這一幀 cy >= line_y
(到達或跨越線下方),代表剛好「由上往下」跨線,計數 +1。previous_center[track_id] = cy
,以防它在同一條線來回上下晃動時多次計數。以下提供完整修改範例:
pythonimport cv2 from sort import * import math import numpy as np from ultralytics import YOLO import cvzone from time import sleep # ============ 載入影片與模型 ============= cap = cv2.VideoCapture('Video0.avi') model = YOLO(r'D:\venv\OTMS1\best.pt') # 讀取類別名稱 classnames = [] with open('classes.txt', 'r') as f: classnames = f.read().splitlines() # 初始化 SORT 追蹤器 tracker = Sort(max_age=20) # ============ 計數線定義 (水平方向) ============ # line = [x1, y, x2, y] 代表此線的左端點(x1, y),與右端點(x2, y) line = [120, 1050, 920, 1050] line_y = line[1] wafer_count = 0 # 保存上一幀各物件 (track_id) 的「中心 y 座標」: { track_id: previous_cy } previous_center = {} # 如果想只要跨越一次就不再重複計數,可以用一個 set 紀錄已計數的 track_id counted_ids = set() while cap.isOpened(): sleep(0.1) # 依需求可調整延遲 ret, frame = cap.read() if not ret: # 讀取不到影像就跳回開頭或直接break cap.set(cv2.CAP_PROP_POS_FRAMES, 0) continue # 1. 使用 YOLO 模型進行偵測 (stream 模式逐張處理) result = model(frame, stream=True) detections = [] # [x1, y1, x2, y2, conf] for info in result: boxes = info.boxes for box in boxes: x1, y1, x2, y2 = box.xyxy[0] conf = box.conf[0] classindex = int(box.cls[0]) x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) conf = float(conf) # 收集偵測框,後面給 SORT 追蹤 detections.append([x1, y1, x2, y2, conf]) # 繪製 YOLO 偵測框 (綠色) 與標籤 cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2) cvzone.putTextRect(frame, f'{classnames[classindex]} {math.ceil(conf * 100)}%', [x1 + 8, y1 - 12], thickness=2, scale=3) # 在偵測階段也可標示中心點 cx = (x1 + x2) // 2 cy = (y1 + y2) // 2 cv2.circle(frame, (cx, cy), 5, (0, 0, 255), -1) # 2. 更新追蹤器:將本幀的偵測結果 (dets) 丟給 SORT,得到追蹤物件 if len(detections) > 0: dets = np.array(detections) else: dets = np.empty((0, 5)) tracks = tracker.update(dets) # 3. 繪製「計數線」(藍色) cv2.line(frame, (line[0], line_y), (line[2], line_y), (255, 0, 0), thickness=10) # 4. 遍歷追蹤到的物件 for track in tracks: x1, y1, x2, y2, track_id = track x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) track_id = int(track_id) # 追蹤框 (藍色) 與 ID cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 0), 2) cvzone.putTextRect(frame, str(track_id), [x1, y1], scale=2, thickness=2) # 計算中心點 cx = (x1 + x2) // 2 cy = (y1 + y2) // 2 cv2.circle(frame, (cx, cy), 5, (0, 0, 255), -1) # =========== (A) 簡易做法:只要物件中心「碰到線」就計數一次 =========== # 把 crosses_line 的條件改成你想要的「cx 在 line 的 x 範圍內,且 cy 與 line_y 的差距很小」等等 # crosses_line = (line[0] <= cx <= line[2]) and (abs(cy - line_y) < 10) # if crosses_line and track_id not in counted_ids: # wafer_count += 1 # counted_ids.add(track_id) # =========== (B) 常見做法:檢查中心點是否跨越線(由上到下) =========== if track_id not in counted_ids: # 若該追蹤 ID 還沒計數過,檢查是否「上一幀在線上方、這幀在線下方」 pcy = previous_center.get(track_id, None) # 上一幀的 y if pcy is not None: # 檢查是否剛好由 上 -> 下 (或接觸線) # 且中心x在計數線範圍內 if pcy < line_y <= cy and (line[0] <= cx <= line[2]): wafer_count += 1 # 計數一次就把它加入 counted_ids,代表已計數,不再重複計數 counted_ids.add(track_id) # 更新本幀中心點當作下一次的 previous previous_center[track_id] = cy # 5. 在影像上顯示計數 cvzone.putTextRect(frame, f'Count: {wafer_count}', [50, 50], scale=3, thickness=2) # 6. 顯示影像 cv2.namedWindow('Wafer Detection', cv2.WINDOW_NORMAL) cv2.resizeWindow('Wafer Detection', 500, 700) cv2.imshow('Wafer Detection', frame) if cv2.waitKey(1) & 0xFF == 27: # 按 ESC 離開 break cap.release() cv2.destroyAllWindows()
計數線位置
line = [120, 1050, 920, 1050]
,表示一條從 (120, 1050)
到 (920, 1050)
的水平線。1050
,那麼此線可能完全看不到、或物件也根本沒走到那麼下方。line = [x1, y, x2, y]
使其在畫面中間或合適位置。跨線方向
pcy < line_y <= cy
)。若你要由下往上,也可反向檢查 (pcy > line_y >= cy
)。是否只計一次
counted_ids
的方式,可以保證每個 track ID 只被計數一次。如果你想一個 ID 可以反覆跨線多次,都要計數,就不需要 counted_ids
,改為只要偵測到跨線就 wafer_count += 1
。如果 SORT 斷追蹤
max_age
、或使用更穩定的追蹤策略。如果你想保持原本 (line[0] <= cx <= line[2]) and abs(cy - line[1]) < 10
的條件,只要確保它「只會被算一次」就好,可以這樣做:
python# ... line = [120, 1050, 920, 1050] wafer_count = 0 counted_ids = set() # 上一幀中心位置 previous_center = {} while cap.isOpened(): ret, frame = cap.read() if not ret: cap.set(cv2.CAP_PROP_POS_FRAMES, 0) continue # ... YOLO檢測 + SORT 追蹤 ... for track in tracks: x1, y1, x2, y2, track_id = track x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) track_id = int(track_id) cx = (x1 + x2) // 2 cy = (y1 + y2) // 2 # 只有當此 track 尚未被計數過時才檢查 if track_id not in counted_ids: # 如果上一幀有記錄它的位置 pcy = previous_center.get(track_id, None) if pcy is not None: # 假設要偵測該物件中心點接近計數線 (±10px) crosses_line = (line[0] <= cx <= line[2]) and (abs(cy - line[1]) < 10) # 若確定想「上一幀在上方,這幀到線附近」再算,可額外檢查 (pcy < line[1] and cy >= line[1]) if crosses_line: wafer_count += 1 counted_ids.add(track_id) # 更新 center previous_center[track_id] = cy # ... 繪圖, 顯示 ...
原理一樣,只是你仍沿用「物件中心點落在計數線±10px 區域中」這種條件,但使用 counted_ids
確保同一個 track ID 只算一次。
counted_ids
(或更新 previous_center
)避免重複計算。以上方法能大幅降低「計數不正確」的情況,請你根據實際需求與影像狀況調整參數與判斷條件。祝你順利完成計數功能!