列车轨道及其障碍物检测相关算法

目录

一、开源算法来源

1.1  列车轨道+障碍物检测(AI算法)

1.2 列车轨道(滤波算法)

1.3 列车轨道(滤波算法)

二、运行代码

2.3.1 具体流程

2.3.2 详细代码

2.3.3 运行步骤


一、开源算法来源

1.1  列车轨道+障碍物检测(AI算法)

GitHub - ELKYang/RailWay_Detection: 电车轨道与障碍物检测(SJTU数字图像处理课程设计)

1.2 列车轨道(滤波算法)

火车轨道铁路轨道检测识别(弯轨+直轨)通用性(Python源码+讲解)_opencv 火车识别-CSDN博客

1.3 列车轨道(滤波算法)

铁路检测概念验证:基于摄像头的自动列车系统-CSDN博客

  

二、运行代码

2.3.1 具体流程

[图 1:算法具体流程图(原图未包含在本文本中)]

2.3.2 详细代码

import cv2
import numpy as np
import argparse
from scipy.special import comb  # 用于计算组合数,生成贝塞尔曲线

# Command-line arguments. They are parsed at import time because main() reads
# the module-level `args` object.
parser = argparse.ArgumentParser(
    description='Detect railway tracks in a video using a sliding-window edge tracker.')
# Input video path. Required: main() cannot run without it, and a missing value
# would otherwise only fail later inside OpenCV with an unclear error.
parser.add_argument('-i', "--input", required=True, help="input file video")
# Left column (x, in full-frame pixels) of the region of interest.
parser.add_argument('--leftPoint', type=int, help="Left rail offset", default=400)
# Right column (x, exclusive) of the region of interest.
parser.add_argument('--rightPoint', type=int, help="Right rail offset", default=900)
# Top row (y) of the region of interest; the ROI extends to the bottom of the frame.
parser.add_argument('--topPoint', type=int, help="Top rail offset", default=330)
args = parser.parse_args()


# Sliding-window search parameters, in ROI pixel coordinates.
_SLIDE_START_ROW = 400   # first (bottom-most) window row
_SLIDE_STOP_ROW = 20     # exclusive end row; the search walks upward
_SLIDE_INTERVAL = 15     # vertical step between consecutive windows
_SLIDE_HEIGHT = 15       # window height in rows
_SLIDE_WIDTH = 60        # base width used to derive each window's half-width
_WINDOW_MAX_X = 390      # right-most column a re-centred window may reach
# Initial (lo, hi) column spans of the (left rail, right rail) windows.
_INITIAL_WINDOWS = ((20, 190), (200, 370))

# Every row is [-1, 1, 0, 1, -1]: positive weights one pixel either side of the
# centre, negative weights two pixels out, so narrow bright vertical structures
# give the strongest response. An explicit float32 dtype avoids depending on
# the platform's default integer dtype (int64 on many Linux builds).
_EDGE_KERNEL = np.array([[-1, 1, 0, 1, -1]] * 5, dtype=np.float32)


def _window_peak(conv_frame, row, lo, hi, height):
    """Return the ROI column of the strongest summed response in one window.

    Responses are summed over rows [row, row + height) for each column in
    [lo, hi). Returns None when the window is empty (for example when it lies
    outside the ROI) or when the maximum is at the window's first column. That
    also covers an all-zero response, which the algorithm treats as
    "no detection".
    """
    response = conv_frame[row:row + height, lo:hi].sum(axis=0)
    if response.size == 0:
        return None
    offset = int(response.argmax())
    if offset == 0:
        return None
    return lo + offset


def _draw_bezier(canvas, points, colour):
    """Draw a 2 px anti-aliased Bezier curve through ``points`` onto ``canvas``.

    Does nothing for an empty point list, so a rail without detections does
    not prevent the other rail's curve from being drawn.
    """
    if not points:
        return
    xs, ys = bezier_curve(points, 100)
    curve = np.column_stack((xs, ys)).astype(np.int32).reshape(-1, 1, 2)
    cv2.polylines(canvas, [curve], False, colour, 2, cv2.LINE_AA)


def main():
    """Detect and draw the two rails in every frame of ``args.input``.

    For each frame, the function:

    1. Crops the region of interest selected by --topPoint/--leftPoint/--rightPoint.
    2. Converts it to grayscale, equalises the histogram and blurs it.
    3. Blends it with the previous frames for temporal stability.
    4. Applies a narrow vertical-line filter.
    5. Tracks each rail upward with a shrinking sliding window.
    6. Draws a Bezier curve through the detected points.

    The annotated ROI is shown in a window named 'Video'. Press 'q' or Esc to
    stop early.
    """
    cap = VideoCapture(args.input)
    top, left, right = args.topPoint, args.leftPoint, args.rightPoint
    half_height = _SLIDE_HEIGHT // 2
    merge_frame = None  # running temporal blend; None until the first frame

    while True:
        _, frame = cap.read()
        if frame is None:
            break

        # Region of interest. This is a view into `frame`, so drawing on it
        # also draws on the frame.
        valid_frame = frame[top:, left:right]
        if valid_frame.size == 0:
            raise ValueError(
                f"empty ROI: --topPoint/--leftPoint/--rightPoint ({top}/{left}/{right}) "
                f"do not fit the {frame.shape[1]}x{frame.shape[0]} frame")

        # Pre-processing: grayscale -> histogram equalisation -> Gaussian blur.
        gray_frame = cv2.cvtColor(valid_frame, cv2.COLOR_BGR2GRAY)
        equalized_frame = cv2.equalizeHist(gray_frame)
        blur_frame = cv2.GaussianBlur(equalized_frame, (5, 5), 5)

        # Temporal smoothing: 20% current frame + 80% history. addWeighted
        # returns a new array, so no defensive copy is needed.
        if merge_frame is None:
            merge_frame = blur_frame
        else:
            merge_frame = cv2.addWeighted(blur_frame, 0.2, merge_frame, 0.8, 0)

        # ddepth=-1 keeps the uint8 depth, so negative responses saturate to 0.
        conv_frame = cv2.filter2D(merge_frame, -1, _EDGE_KERNEL)

        # Per rail: mutable [lo, hi] window, detected points, rectangle colour (BGR).
        rails = (
            ([*_INITIAL_WINDOWS[0]], [], (0, 255, 0)),
            ([*_INITIAL_WINDOWS[1]], [], (0, 0, 255)),
        )
        rows = range(_SLIDE_START_ROW, _SLIDE_STOP_ROW, -_SLIDE_INTERVAL)
        for count, row in enumerate(rows):
            y = row + half_height
            # The window narrows as it moves further from the bottom of the ROI.
            margin = int(_SLIDE_WIDTH / 4 + (_SLIDE_WIDTH + 10) / (count + 1))
            for window, points, colour in rails:
                x = _window_peak(conv_frame, row, window[0], window[1], _SLIDE_HEIGHT)
                if x is None:
                    continue
                # White dot at the detection, then re-centre the window on it.
                cv2.line(valid_frame, (x, y), (x, y), (255, 255, 255), 5, cv2.LINE_AA)
                points.append([x, y])
                window[0] = max(0, x - margin)
                window[1] = min(_WINDOW_MAX_X, x + margin)
                cv2.rectangle(valid_frame, (window[0], row + _SLIDE_HEIGHT),
                              (window[1], row), colour, 1)

        _draw_bezier(valid_frame, rails[0][1], (0, 0, 255))  # left rail, red
        _draw_bezier(valid_frame, rails[1][1], (255, 0, 0))  # right rail, blue

        cv2.imshow('Video', valid_frame)
        # The 10 ms wait also sets the playback speed; 'q' or Esc quits.
        if (cv2.waitKey(10) & 0xFF) in (ord('q'), 27):
            break

    cv2.destroyAllWindows()
    print('finish')


# Video capture wrapper.
class VideoCapture:
    """Minimal wrapper around ``cv2.VideoCapture`` for reading frames.

    The underlying capture is released by ``release()``, when a ``with`` block
    exits, or when the object is garbage-collected.
    """

    def __init__(self, path):
        # `path` is forwarded to cv2.VideoCapture unchanged (here, the video
        # file given on the command line).
        self.video = cv2.VideoCapture(path)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.release()
        return False

    def __del__(self):
        self.release()

    def release(self):
        """Release the underlying capture once; safe to call repeatedly.

        Also tolerates a partially constructed object (``__init__`` raised
        before ``self.video`` was assigned), which used to make ``__del__``
        raise AttributeError.
        """
        video = getattr(self, "video", None)
        if video is not None and not getattr(self, "_released", False):
            video.release()
            self._released = True

    def read(self):
        """Grab the next frame.

        Returns ``(ok, frame)``. ``frame`` is whatever cv2 returned (None at
        end of stream or on error), and ``ok`` is ``frame is not None``.
        """
        _, frame = self.video.read()
        return frame is not None, frame


# Bezier curve sampling.
def bezier_curve(points, ntimes=1000):
    """
    Sample the Bezier curve defined by a sequence of control points.

    points : sequence of [x, y] pairs (lists or tuples), e.g.
             [[1, 1], [2, 3], [4, 5], ..., [Xn, Yn]]
    ntimes : number of samples along the curve; defaults to 1000.

    Returns (xvals, yvals): two int32 arrays of length ``ntimes``, with the
    coordinates rounded to the nearest integer pixel. The sample order runs
    from the LAST control point (first sample) to the FIRST control point
    (last sample). An empty ``points`` yields two empty int32 arrays.

    See http://processingjs.nihongoresources.com/bezierinfo/
    """
    if len(points) == 0:
        return np.empty(0, dtype=np.int32), np.empty(0, dtype=np.int32)

    control = np.asarray(points, dtype=float)
    degree = len(control) - 1
    t = np.linspace(0.0, 1.0, ntimes)
    i = np.arange(degree + 1)[:, None]
    # Bernstein basis, shape (degree + 1, ntimes). The powers of t and (1 - t)
    # are swapped relative to the textbook form; this keeps the historical
    # sample order (last control point first).
    basis = comb(degree, i) * t ** (degree - i) * (1 - t) ** i
    xvals = control[:, 0] @ basis
    yvals = control[:, 1] @ basis
    # Round rather than truncate: floating error such as 199.9999 must map to 200.
    return np.rint(xvals).astype(np.int32), np.rint(yvals).astype(np.int32)


def nothing(value):
    """Accept a single argument, ignore it, and return None.

    Not referenced in this file; presumably kept as a placeholder callback
    (e.g. for an OpenCV trackbar). Confirm before removing.
    """
    return None


if __name__ == "__main__":
    # Run the detector only when executed as a script, not on import.
    main()

2.3.3 运行步骤

1、 进入文件夹,打开终端

2、安装依赖库(代码用到 opencv-python、numpy、scipy)

  • pip install -r requirements.txt(若没有该文件,可直接执行 pip install opencv-python numpy scipy)

3、运行代码

  • python main.py --input video/test1.mp4

  • 可选参数 --leftPoint、--rightPoint、--topPoint 用于设置裁剪区域(感兴趣区域)的左、右、上边界像素位置,默认值分别为 400、900、330

4、实验结果

你可能感兴趣的:(车道检测研究,列车轨道检测)