目录
一、开源算法来源
1.1 列车轨道+障碍物检测(AI算法)
1.2 列车轨道(滤波算法)
1.3 列车轨道(滤波算法)
二、运行代码
2.3.1 具体流程
2.3.2 详细代码
2.3.3 运行步骤
GitHub - ELKYang/RailWay_Detection: 电车轨道与障碍物检测(SJTU数字图像处理课程设计)
火车轨道铁路轨道检测识别(弯轨+直轨)通用性(Python源码+讲解)_opencv 火车识别-CSDN博客
铁路检测概念验证:基于摄像头的自动列车系统-CSDN博客
import cv2
import numpy as np
import argparse
from scipy.special import comb # 用于计算组合数,生成贝塞尔曲线
# Command-line options. They are parsed once at module level and read by
# main() through the module-level ``args`` namespace.
# NOTE(review): the description string below looks like the argparse tutorial
# text and does not describe this script - consider rewording it.
parser = argparse.ArgumentParser(description='Process some integers.')
parser.add_argument('-i', "--input", help="input file video") # path of the input video
parser.add_argument('--leftPoint', type=int, help="Left rail offset", default=400) # left crop column of the processed region
parser.add_argument('--rightPoint', type=int, help="Right rail offset", default=900) # right crop column of the processed region
parser.add_argument('--topPoint', type=int, help="Top rail offset", default=330) # top crop row of the processed region
args = parser.parse_args()
def _track_rail(valid_frame, conv_frame, row, window, count, box_color,
                slide_height=15, slide_width=60, max_x=390):
    """Find the strongest edge column of one rail inside one window row.

    Args:
        valid_frame: cropped colour image the marker and window box are drawn on.
        conv_frame: edge-response image whose columns are searched.
        row: top row of the current window.
        window: ``(x_start, x_stop)`` column range searched for this rail.
        count: zero-based index of the current window row; the next search
            range gets narrower as it grows.
        box_color: BGR colour of the window rectangle.
        slide_height: window height in rows.
        slide_width: base width used to size the next search range.
        max_x: upper clamp for the right edge of the next search range.

    Returns:
        ``(point, next_window)``. ``point`` is ``[x, y]`` of the detection, or
        ``None`` when nothing was found, in which case ``window`` is returned
        unchanged.
    """
    x_start, x_stop = window
    edge = conv_frame[row:row + slide_height, x_start:x_stop].sum(axis=0)
    # An empty column slice has no argmax, and a peak at column 0 is what an
    # all-zero response yields, so both cases count as "no rail found".
    if edge.size == 0 or edge.argmax() <= 0:
        return None, window

    # Convert the NumPy index to a plain int before using it as a coordinate.
    peak_x = x_start + int(edge.argmax())
    mid_y = row + slide_height // 2

    # A zero-length thick line is drawn as a white dot on the detection.
    cv2.line(valid_frame, (peak_x, mid_y), (peak_x, mid_y),
             (255, 255, 255), 5, cv2.LINE_AA)

    # Re-centre the search range on the detection and shrink it row by row.
    half_width = int(slide_width / 4 + (slide_width + 10) / (count + 1))
    next_window = (max(0, peak_x - half_width), min(max_x, peak_x + half_width))
    cv2.rectangle(valid_frame, (next_window[0], row + slide_height),
                  (next_window[1], row), box_color, 1)
    return [peak_x, mid_y], next_window


def _draw_bezier(valid_frame, points, color):
    """Draw the Bezier curve defined by ``points`` on ``valid_frame``; no-op if empty."""
    if not points:
        return
    xs, ys = bezier_curve(points, 100)
    previous = (int(xs[0]), int(ys[0]))
    for x, y in zip(xs, ys):
        current = (int(x), int(y))
        cv2.line(valid_frame, previous, current, color, 2, cv2.LINE_AA)
        previous = current


def main():
    """Play the input video with the detected left and right rails drawn on it.

    Every frame is cropped to the region given by the command-line offsets,
    converted to an edge-response image, scanned bottom-up with two sliding
    windows (one per rail), and the detections of each rail are joined by a
    Bezier curve. The annotated crop is shown in a window named 'Video'.
    """
    # Open the input video given by the --input option.
    cap = VideoCapture(args.input)

    # Crop offsets of the processed region inside each full frame.
    roi_left = args.leftPoint
    roi_right = args.rightPoint
    roi_top = args.topPoint

    # Every kernel row is identical, so the filter only reacts to intensity
    # changes along the horizontal direction (i.e. to vertical structures).
    kernel = np.array([
        [-1, 1, 0, 1, -1],
        [-1, 1, 0, 1, -1],
        [-1, 1, 0, 1, -1],
        [-1, 1, 0, 1, -1],
        [-1, 1, 0, 1, -1],
    ])

    slide_interval = 15  # vertical step between consecutive window rows
    previous_merge = None  # blended image of the previous frame, None on the first frame
    window_shown = False

    try:
        while True:
            _, frame = cap.read()
            if frame is None:
                break

            # Pre-processing: crop, grey-scale, equalise contrast, blur.
            valid_frame = frame[roi_top:, roi_left:roi_right]
            gray_frame = cv2.cvtColor(valid_frame, cv2.COLOR_BGR2GRAY)
            equalized_frame = cv2.equalizeHist(gray_frame)
            blur_frame = cv2.GaussianBlur(equalized_frame, (5, 5), 5)

            # Temporal smoothing: blend 20% of the current frame with 80% of
            # the previous blended result.
            if previous_merge is None:
                merge_frame = blur_frame
            else:
                merge_frame = cv2.addWeighted(blur_frame, 0.2, previous_merge, 0.8, 0)
            previous_merge = merge_frame.copy()

            conv_frame = cv2.filter2D(merge_frame, -1, kernel)

            # Initial column ranges searched for the left and the right rail.
            left_window = (20, 190)
            right_window = (200, 370)
            left_points = []
            right_points = []

            # Scan bottom-up: window rows start at 400 and end at 25.
            for count, row in enumerate(range(400, 20, -slide_interval)):
                point, left_window = _track_rail(
                    valid_frame, conv_frame, row, left_window, count, (0, 255, 0))
                if point is not None:
                    left_points.append(point)

                point, right_window = _track_rail(
                    valid_frame, conv_frame, row, right_window, count, (0, 0, 255))
                if point is not None:
                    right_points.append(point)

            # Each rail is drawn independently, so a rail without detections
            # no longer prevents the other rail's curve from being drawn.
            _draw_bezier(valid_frame, left_points, (0, 0, 255))   # left rail, red
            _draw_bezier(valid_frame, right_points, (255, 0, 0))  # right rail, blue

            cv2.imshow('Video', valid_frame)
            window_shown = True
            cv2.waitKey(10)  # delay per frame in ms; controls playback speed
    finally:
        # Free the capture and the window even when processing fails midway.
        cap.video.release()
        if window_shown:
            cv2.destroyAllWindows()

    print('finish')
# Video capture wrapper class
class VideoCapture:
    """Small wrapper around ``cv2.VideoCapture`` that reports success by frame.

    The underlying capture object is kept in ``self.video`` and is released
    when the wrapper is garbage-collected.
    """

    def __init__(self, path):
        """Open the video source ``path`` with ``cv2.VideoCapture``."""
        self.video = cv2.VideoCapture(path)

    def __del__(self):
        """Release the underlying capture, if one was ever created."""
        # __del__ also runs when __init__ raised before ``self.video`` was
        # assigned; look the attribute up defensively so that case is silent.
        video = getattr(self, 'video', None)
        if video is not None:
            video.release()

    def read(self):
        """Grab one frame.

        Returns:
            ``(ok, frame)`` where ``ok`` is ``True`` exactly when ``frame`` is
            not ``None``.
        """
        _, frame = self.video.read()
        return frame is not None, frame
# Bezier curve generator
def bezier_curve(points, ntimes=1000):
    """Sample the Bezier curve defined by a list of control points.

    Args:
        points: sequence of ``[x, y]`` pairs (lists or tuples), e.g.
            ``[[1, 1], [2, 3], [4, 5]]``.
        ntimes: number of samples taken along the curve, defaults to 1000.

    Returns:
        ``(xvals, yvals)``: two int32 arrays of length ``ntimes`` holding the
        sample coordinates, truncated to integers. With the weighting used
        here the first sample is the last control point and the last sample
        is the first control point. Two empty int32 arrays are returned when
        ``points`` is empty.

    See http://processingjs.nihongoresources.com/bezierinfo/
    """

    def bernstein_poly(i, n, t):
        """Weight of control point ``i`` (of ``n + 1``) as a function of ``t``."""
        return comb(n, i) * (t ** (n - i)) * (1 - t) ** i

    nPoints = len(points)
    if nPoints == 0:
        # Without control points the dot products below collapse to scalars,
        # which callers cannot index or iterate; return empty arrays instead.
        return np.empty(0, dtype='int32'), np.empty(0, dtype='int32')

    xPoints = np.array([p[0] for p in points])
    yPoints = np.array([p[1] for p in points])
    t = np.linspace(0.0, 1.0, ntimes)
    # One row of weights per control point, one column per sample.
    polynomial_array = np.array([bernstein_poly(i, nPoints - 1, t) for i in range(nPoints)])
    xvals = np.dot(xPoints, polynomial_array)
    yvals = np.dot(yPoints, polynomial_array)
    return xvals.astype('int32'), yvals.astype('int32')
def nothing(value):
    """Accept one argument and ignore it; usable as a do-nothing callback."""
    return None
# Start the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()
1、 进入文件夹,打开终端
2、 安装依赖库
- pip install -r requirements.txt
3、运行代码
- python main.py --input video/test1.mp4
4、实验结果