首页 >web前端 >PS教程 >Photoshop中磁力套索的一种简陋实现(基于Python)

Photoshop中磁力套索的一种简陋实现(基于Python)

高洛峰
高洛峰原创
2017-02-18 13:38:093289浏览

经常用Photoshop的人应该熟悉磁力套索(Magnetic Lasso)这个功能,就是人为引导下的抠图辅助工具。在研发领域一般不这么叫,通常管这种边缘提取的办法叫Intelligent Scissors或者Livewire。

本来是给一个图像分割项目算法评估时的Python框架,觉得有点意思,就稍稍拓展了一下,用PyQt加了个壳,在非常简陋的程度上模拟了一下的磁力套索功能。为什么简陋:1) 只实现了最基本的边缘查找。路径冷却,动态训练,鼠标位置修正都没有,更别提曲线闭合,抠图,Alpha Matting等等;2) 没考虑性能规范,只为了书写方便;3) 我对Qt了解很浅,至今不会写Signal-Slot,不知道GUI写得是否合理;4) 没调试。

Photoshop中磁力套索的一种简陋实现(基于Python)

基本算法

相关算法我并没有做很深入的调研,不过相信这类应用中影响力最大的算法是来源于[1],也是本文的主要参考,基本思想是把图片看成是一个无向图,相邻像素之间就可以计算出一个局部cost,于是就转化成了最短路径问题了,接下来就是基于Dijkstra算法产生路径,就是需要提取的边缘。主要涉及的算法有两部分:1) 相邻像素的cost计算;2) 最短路径算法。

边缘检测

计算相邻像素cost的最终目的还是为了寻找边缘,所以本质还是边缘检测。基本思想是,通过各种不同手段检测边缘,并且根据检测到的强度来求加权值,作为cost。从最短路径的角度来说,就是边缘越明显的地方,cost的值越小。[1]中的建议是用三种指标求加权:1) 边缘检测算子;2) 梯度强度(Gradient Magnitude);3) 梯度方向(Gradient Direction)。本文的方法和[1]有那么一些不一样,因为懒,用了OpenCV中的Canny算子检测边缘而不是Laplacian Zero-Crossing Operator。表达式如下:

\[l\left( p,q \right)={{w}_{E}}{{f}_{E}}\left( q \right)+{{w}_{G}}{{f}_{G}}\left( q \right)+{{w}_{D}}{{f}_{D}}\left( p,q \right)\]

Canny算子

基本思想是根据梯度信息,先检测出许多连通的像素,然后对于每一坨连通的像素只取其中最大值且连通的部分,将周围置零,得到初始的边缘(Edges),这个过程叫做Non-Maximum Suppression。然后用二阈值的办法将这些检测到的初始边缘分为Strong, Weak, and None三个等级,顾名思义,Strong就是很确定一定是边缘了,None就被舍弃,然后从Weak中挑选和Strong连通的作为保留的边缘,得到最后的结果,这个过程叫做Hysteresis Thresholding。这个算法太经典,更多细节一Google出来一大堆,我就不赘述了。公式如下:

\[{{f}_{E}}\left( q \right)=\left\{ \begin{matrix}
   0;\text{ if }q\text{ is on a edge}  \\
   1;\text{ if }q\text{ is not on a edge}  \\
\end{matrix} \right.\]

其实从权值的计算上和最大梯度有些重复,因为如果沿着最大梯度方向找出来的路径基本上就是边缘,这一项的作用我的理解主要应该是1) 避免梯度都很大的区域出现离明显边缘的偏离;2) 保证提取边缘的连续性,一定程度上来讲也是保证平滑。

梯度强度

就是梯度求模而已,x和y两个方向的梯度值平方相加在开方,公式如下:

\[{{I}_{G}}\left( q \right)=\sqrt{{{I}_{x}}\left( q \right)+{{I}_{y}}\left( q \right)}\]

因为要求cost,所以反向并归一化:

\[{{f}_{G}}\left( q \right)=1-\frac{{{I}_{G}}\left( q \right)}{\max \left( {{I}_{G}} \right)}\]

梯度方向

这一项其实是个平滑项,会给变化剧烈的边缘赋一个比较高的cost,让提取的边缘避免噪声的影响。具体公式如下:

\[{{f}_{D}}\left( p,q \right)=\frac{2}{3\pi }\left( \arccos \left( {{d}_{p}}\left( p,q \right) \right)+\arccos \left( {{d}_{q}}\left( p,q \right) \right) \right)\]

其中,

\[{{d}_{p}}\left( p,q \right)=\left\langle {{d}_{\bot }}\left( p \right),{{l}_{D}}\left( p,q \right) \right\rangle \]

\[{{d}_{q}}\left( p,q \right)=\left\langle {{l}_{D}}\left( p,q \right),{{d}_{\bot }}\left( q \right) \right\rangle \]

\[{{l}_{D}}\left( p,q \right)=\left\{ \begin{matrix}
   q-p;\text{ if }\left\langle {{d}_{\bot }}\left( p \right),q-p \right\rangle \ge 0  \\
   p-q;\text{ if }\left\langle {{d}_{\bot }}\left( p \right),q-p \right\rangle 29a9fd1e78ecc91c3b69746be10975303\),所以正上方的才是最小cost的正确方向。

最短路径查找

在磁力套索中,一般的用法是先单击一个点,然后移动鼠标,在鼠标和一开始单击的点之间就会出现自动贴近边缘的线,这里我们定义一开始单击的像素点为种子点(seed),而磁力套索其实在考虑上部分提到的边缘相关cost的情况下查找种子点到当前鼠标的最短路径。如下图,红色的就是种子点,而移动鼠标时,最贴近边缘的种子点和鼠标坐标的连线就会实时显示,这也是为什么磁力套索也叫Livewire。

Photoshop中磁力套索的一种简陋实现(基于Python)

实现最短路径的办法很多,一般而言就是动态规划了,这里介绍的是基于Dijkstra算法的一种实现,基本思想是,给定种子点后,执行Dijkstra算法将图像的所有像素遍历,得到每个像素到种子点的最短路径。以下面这幅图为例,在一个cost矩阵中,利用Dijkstra算法遍历每一个元素后,每个元素都会指向一个相邻的元素,这样任意一个像素都能找到一条到seed的路径,比如右上角的42和39对应的像素,沿着箭头到了0。

Photoshop中磁力套索的一种简陋实现(基于Python)

算法如下:

输入:
  s              // 种子点
  l(q,r)         // 计算局部cost

数据结构:
  L             // 当前待处理的像素
  N(q)          // 当前像素相邻的像素
  e(q)          // 标记一个像素是否已经做过相邻像素展开的Bool函数
  g(q)          // 从s到q的总cost

输出:
  p             // 记录所有路径的map

算法:
  g(s)←0; L←s;                 // 将种子点作为第一点初始化
  while L≠Ø:                   // 遍历尚未结束
    q←min(L);                  // 取出最小cost的像素并从待处理像素中移除
    e(q)←TRUE;                 // 将当前像素记录为已经做过相邻像素展开
    for each r∈N(q) and not e(r):
      gtemp←g(q)+l(q,r);        // 计算相邻像素的总cost
      if r∈L and gtemp<g(r):    // 找到了更好的路径
        r←L; { from list.}     // 舍弃较大cost的路径
      else if r∉L:
        g(r)←gtemp;             // 记录当前找到的最小路径
        p(r)←q;
        L←r;                    // 加入待处理以试图寻找更短的路径


Photoshop中磁力套索的一种简陋实现(基于Python) 遍历的过程会优先经过cost最低的区域,如下图:

所有像素对应的到种子像素的最短路径都找到后,移动鼠标时就直接画出到seed的最短路径就可以了。

Python实现

算法部分直接调用了OpenCV的Canny函数和Sobel函数(求梯度),对于RGB的处理也很简陋,直接用梯度最大的值来近似。另外因为懒,cost map和path map都直接用了字典(dict),而记录展开过的像素则直接采用了集合(set)。GUI部分因为不会用QThread所以用了Python的threading,只有图像显示交互区域和状态栏提示,左键点击设置种子点,右键结束,已经提取的边缘为绿色线,正在提取的为蓝色线。

Photoshop中磁力套索的一种简陋实现(基于Python)

代码

算法部分

from __future__ import division
import cv2
import numpy as np

SQRT_0_5 = 0.70710678118654757

class Livewire():
    """
    A simple livewire implementation for verification using 
        1. Canny edge detector + gradient magnitude + gradient direction
        2. Dijkstra algorithm
    """
    
    def __init__(self, image):
        self.image = image
        self.x_lim = image.shape[0]
        self.y_lim = image.shape[1]
        # The values in cost matrix ranges from 0~1
        self.cost_edges = 1 - cv2.Canny(image, 85, 170)/255.0
        self.grad_x, self.grad_y, self.grad_mag = self._get_grad(image)
        self.cost_grad_mag = 1 - self.grad_mag/np.max(self.grad_mag)
        # Weight for (Canny edges, gradient magnitude, gradient direction)
        self.weight = (0.425, 0.425, 0.15)
        
        self.n_pixs = self.x_lim * self.y_lim
        self.n_processed = 0
    
    @classmethod
    def _get_grad(cls, image):
        """
        Return the gradient magnitude of the image using Sobel operator
        """
        rgb = True if len(image.shape) > 2 else False
        grad_x = cv2.Sobel(image, cv2.CV_64F, 1, 0, ksize=3)
        grad_y = cv2.Sobel(image, cv2.CV_64F, 0, 1, ksize=3)
        if rgb:
            # A very rough approximation for quick verification...
            grad_x = np.max(grad_x, axis=2)
            grad_y = np.max(grad_y, axis=2)
            
        grad_mag = np.sqrt(grad_x**2+grad_y**2)
        grad_x /= grad_mag
        grad_y /= grad_mag
        
        return grad_x, grad_y, grad_mag
    
    def _get_neighbors(self, p):
        """
        Return 8 neighbors around the pixel p
        """
        x, y = p
        x0 = 0 if x == 0 else x - 1
        x1 = self.x_lim if x == self.x_lim - 1 else x + 2
        y0 = 0 if y == 0 else y - 1
        y1 = self.y_lim if y == self.y_lim - 1 else y + 2
        
        return [(x, y) for x in xrange(x0, x1) for y in xrange(y0, y1) if (x, y) != p]
    
    def _get_grad_direction_cost(self, p, q):
        """
        Calculate the gradient changes refer to the link direction
        """
        dp = (self.grad_y[p[0]][p[1]], -self.grad_x[p[0]][p[1]])
        dq = (self.grad_y[q[0]][q[1]], -self.grad_x[q[0]][q[1]])
        
        l = np.array([q[0]-p[0], q[1]-p[1]], np.float)
        if 0 not in l:
            l *= SQRT_0_5
        
        dp_l = np.dot(dp, l)
        l_dq = np.dot(l, dq)
        if dp_l < 0:
            dp_l = -dp_l
            l_dq = -l_dq
        
        # 2/3pi * ...
        return 0.212206590789 * (np.arccos(dp_l)+np.arccos(l_dq))
    
    def _local_cost(self, p, q):
        """
        1. Calculate the Canny edges & gradient magnitude cost taking into account Euclidean distance
        2. Combine with gradient direction
        Assumption: p & q are neighbors
        """
        diagnol = q[0] == p[0] or q[1] == p[1]
        
        # c0, c1 and c2 are costs from Canny operator, gradient magnitude and gradient direction respectively
        if diagnol:
            c0 = self.cost_edges[p[0]][p[1]]-SQRT_0_5*(self.cost_edges[p[0]][p[1]]-self.cost_edges[q[0]][q[1]])
            c1 = self.cost_grad_mag[p[0]][p[1]]-SQRT_0_5*(self.cost_grad_mag[p[0]][p[1]]-self.cost_grad_mag[q[0]][q[1]])
            c2 = SQRT_0_5 * self._get_grad_direction_cost(p, q)
        else:
            c0 = self.cost_edges[q[0]][q[1]]
            c1 = self.cost_grad_mag[q[0]][q[1]]
            c2 = self._get_grad_direction_cost(p, q)
        
        if np.isnan(c2):
            c2 = 0.0
        
        w0, w1, w2 = self.weight
        cost_pq = w0*c0 + w1*c1 + w2*c2
        
        return cost_pq * cost_pq

    def get_path_matrix(self, seed):
        """
        Get the back tracking matrix of the whole image from the cost matrix
        """
        neighbors = []          # 8 neighbors of the pixel being processed
        processed = set()       # Processed point
        cost = {seed: 0.0}      # Accumulated cost, initialized with seed to itself
        paths = {}

        self.n_processed = 0
        
        while cost:
            # Expand the minimum cost point
            p = min(cost, key=cost.get)
            neighbors = self._get_neighbors(p)
            processed.add(p)

            # Record accumulated costs and back tracking point for newly expanded points
            for q in [x for x in neighbors if x not in processed]:
                temp_cost = cost[p] + self._local_cost(p, q)
                if q in cost:
                    if temp_cost < cost[q]:
                        cost.pop(q)
                else:
                    cost[q] = temp_cost
                    processed.add(q)
                    paths[q] = p
            
            # Pop traversed points
            cost.pop(p)
            
            self.n_processed += 1
        
        return paths

livewire.py

livewire.py

GUI部分

 from __future__ import division
import time
import cv2
from PyQt4 import QtGui, QtCore
from threading import Thread
from livewire import Livewire

class ImageWin(QtGui.QWidget):
    def __init__(self):
        super(ImageWin, self).__init__()
        self.setupUi()
        self.active = False
        self.seed_enabled = True
        self.seed = None
        self.path_map = {}
        self.path = []
        
    def setupUi(self):
        self.hbox = QtGui.QVBoxLayout(self)
        
        # Load and initialize image
        self.image_path = &#39;&#39;
        while self.image_path == &#39;&#39;:
            self.image_path = QtGui.QFileDialog.getOpenFileName(self, &#39;&#39;, &#39;&#39;, &#39;(*.bmp *.jpg *.png)&#39;)
        self.image = QtGui.QPixmap(self.image_path)
        self.cv2_image = cv2.imread(str(self.image_path))
        self.lw = Livewire(self.cv2_image)
        self.w, self.h = self.image.width(), self.image.height()
        
        self.canvas = QtGui.QLabel(self)
        self.canvas.setMouseTracking(True)
        self.canvas.setPixmap(self.image)
        
        self.status_bar = QtGui.QStatusBar(self)
        self.status_bar.showMessage(&#39;Left click to set a seed&#39;)
        
        self.hbox.addWidget(self.canvas)
        self.hbox.addWidget(self.status_bar)
        self.setLayout(self.hbox)
    
    def mousePressEvent(self, event):            
        if self.seed_enabled:
            pos = event.pos()
            x, y = pos.x()-self.canvas.x(), pos.y()-self.canvas.y()
            
            if x < 0:
                x = 0
            if x >= self.w:
                x = self.w - 1
            if y < 0:
                y = 0
            if y >= self.h:
                y = self.h - 1

            # Get the mouse cursor position
            p = y, x
            seed = self.seed
            
            # Export bitmap
            if event.buttons() == QtCore.Qt.MidButton:
                filepath = QtGui.QFileDialog.getSaveFileName(self, &#39;Save image audio to&#39;, &#39;&#39;, &#39;*.bmp\n*.jpg\n*.png&#39;)
                image = self.image.copy()
                
                draw = QtGui.QPainter()
                draw.begin(image)
                draw.setPen(QtCore.Qt.blue)
                if self.path_map:
                    while p != seed:
                        draw.drawPoint(p[1], p[0])
                        for q in self.lw._get_neighbors(p):
                            draw.drawPoint(q[1], q[0])
                        p = self.path_map[p]
                if self.path:
                    draw.setPen(QtCore.Qt.green)
                    for p in self.path:
                        draw.drawPoint(p[1], p[0])
                        for q in self.lw._get_neighbors(p):
                            draw.drawPoint(q[1], q[0])
                draw.end()
                
                image.save(filepath, quality=100)
            
            else:
                self.seed = p
                
                if self.path_map:
                    while p != seed:
                        p = self.path_map[p]
                        self.path.append(p)
                
                # Calculate path map
                if event.buttons() == QtCore.Qt.LeftButton:
                    Thread(target=self._cal_path_matrix).start()
                    Thread(target=self._update_path_map_progress).start()
                
                # Finish current task and reset
                elif event.buttons() == QtCore.Qt.RightButton:
                    self.path_map = {}
                    self.status_bar.showMessage(&#39;Left click to set a seed&#39;)
                    self.active = False
    
    def mouseMoveEvent(self, event):
        if self.active and event.buttons() == QtCore.Qt.NoButton:
            pos = event.pos()
            x, y = pos.x()-self.canvas.x(), pos.y()-self.canvas.y()

            if x < 0 or x >= self.w or y < 0 or y >= self.h:
                pass
            else:
                # Draw livewire
                p = y, x
                path = []
                while p != self.seed:
                    p = self.path_map[p]
                    path.append(p)
                
                image = self.image.copy()
                draw = QtGui.QPainter()
                draw.begin(image)
                draw.setPen(QtCore.Qt.blue)
                for p in path:
                    draw.drawPoint(p[1], p[0])
                if self.path:
                    draw.setPen(QtCore.Qt.green)
                    for p in self.path:
                        draw.drawPoint(p[1], p[0])
                draw.end()
                self.canvas.setPixmap(image)
    
    def _cal_path_matrix(self):
        self.seed_enabled = False
        self.active = False
        self.status_bar.showMessage(&#39;Calculating path map...&#39;)
        path_matrix = self.lw.get_path_matrix(self.seed)
        self.status_bar.showMessage(r&#39;Left: new seed / Right: finish&#39;)
        self.seed_enabled = True
        self.active = True
        
        self.path_map = path_matrix
    
    def _update_path_map_progress(self):
        while not self.seed_enabled:
            time.sleep(0.1)
            message = &#39;Calculating path map... {:.1f}%&#39;.format(self.lw.n_processed/self.lw.n_pixs*100.0)
            self.status_bar.showMessage(message)
        self.status_bar.showMessage(r&#39;Left: new seed / Right: finish&#39;)

gui.py

gui.py

主函数

import sys
from PyQt4 import QtGui
from gui import ImageWin

def main():
    app = QtGui.QApplication(sys.argv)
    window = ImageWin()
    window.setMouseTracking(True)
    window.setWindowTitle('Livewire Demo')
    window.show()
    window.setFixedSize(window.size())
    sys.exit(app.exec_())

if __name__ == '__main__':
    main()    

main.py

main.py

蛋疼地上传到了Github(传送门),欢迎fork。

效率的改进

因为这个代码的原型只是为了用C++开发之前的Python评估和验证,所以完全没考虑效率,执行速度是完全不行的,基本上400x400的图片就不能忍了……至于基于Python版本的效率提升我没有仔细想过,只是大概看来有这么几个比较明显的地方:

1) 取出当前最小cost像素操作

 p = min(cost, key=cost.get) 

这个虽然写起来很爽但显然是不行的,至少得用个min heap什么的。因为我是用dict同时表示待处理像素和cost了,也懒得想一下怎么和Python的heapq结合起来,所以直接用了粗暴省事的min()。

2) 梯度方向的计算

三角函数的计算应该是尽量避免的,另外原文可能是为了将值域扩展到>π所以把q-p也用上了,其实这一项本来权重就小,那怕直接用两个像素各自的梯度方向向量做点积然后归一化一下结果也是还行的。即使要用arccos,也可以考虑写个look-up table近似。当然我最后想说的是个人觉得其实这项真没那么必要,直接自适应spilne或者那怕三点均值平滑去噪效果就不错了。

3) 计算相邻像素的位置

如果两个像素相邻,则他们各自周围的8个相邻像素也会重合。的我的办法比较原始,可以考率不用模块化直接计算。

4) 替换部分数据结构

比如path map其实本质是给出每个像素在最短路径上的上一个像素,是个矩阵。其实可以考虑用线性的数据结构代替,不过如果真这样做一般来说都是在C/C++代码里了。

5) numpy

我印象中对numpy的调用顺序也会影响到效率,连续调用numpy的内置方法似乎会带来效率的整体提升,不过话还是说回来,实际应用中如果到了这一步,应该也属于C/C++代码范畴了。

6) 算法层面的改进

这块没有深入研究,第一感觉是实际应用中没必要一上来就计算整幅图像,可以根据seed位置做一些区块划分,鼠标本身也会留下轨迹,也或许可以考虑只在鼠标轨迹方向进行启发式搜索。另外计算路径的时候也许可以考虑借鉴有点类似于Image Pyramid的思想,没必要一上来就对全分辨率下的路径进行查找。由于后来做的项目没有采用这个算法,所以我也没有继续研究,虽然挺好奇的,其实有好多现成的代码,比如GIMP,不过没有精力去看了。

更多的改进

虽然都没做,大概介绍一下,都是考虑了实用性的改进。

路径冷却(Path Cooling)

用过Photoshop和GIMP磁力套索的人都知道,即使鼠标不点击图片,在移动过程中也会自动生成一些将抠图轨迹固定住的点,这些点其实就是新的种子点,而这种使用过程中自动生成新的种子点的方法叫Path cooling。这个方法的基本思路如下:随着鼠标移动过程中如果一定时间内一段路径都保持固定不变,那么就把这段路径中离种子最远的点设置为新的种子,其实背后隐藏的还是动态规划的思想,贝尔曼最优。这个名字也是比较形象的,路径冷却。

动态训练(Interactive Dynamic Training)

Photoshop中磁力套索的一种简陋实现(基于Python)

单纯的最短路径查找在使用的时候常常出现找到的边缘不是想要的边缘的问题,比如上图,绿色的线是上一段提取的边缘,蓝色的是当前正在提取的边缘。左图中,镜子外面Lena的帽子边缘是我们想要提取的,然而由于镜子里的Lena的帽子边缘的cost更低,所以实际提取出的蓝色线段如右图中贴到右边了。所以Interactive Dynamic Training的思想是,认为绿色的线段是正确提取的边缘,然后利用绿色线段作为训练数据来给当前提取边缘的cost函数附加一个修正值。

[1]中采用的方法是统计前一段边缘上点的梯度强度的直方图,然后按照梯度出现频率给当前图中的像素加权。举例来说如果绿色线段中的所有像素对应的梯度强度都是在50到100之间的话,那么可以将50到100以10为单位分为5个bin,统计每个bin里的出现频率,也就是直方图,然后对当前检测到的梯度强度,做个线性加权。比方说50~60区间内对应的像素最多有10个,那么把10作为最大值,并且对当前检测到的梯度强度处于50~60之间的像素均乘上系数1.0;如果训练数据中70~80之间有5个,那么cost加权系数为5/10=0.5,则对所有当前检测到的梯度强度处于70~80之间的像素均乘上系数0.5;如果训练数据中100以上没有,所以cost附加为0/10=0,则加权系数为0,这样即使检测到更强的边缘也不会偏离前一段边缘了。这是基本思想,当然实际的实现没有这么简单,除了边缘上的像素还要考虑垂直边缘上左边和右边的两个像素点,这样保证了边缘的pattern。另外随着鼠标越来越远离训练边缘,检测到的边缘的pattern可能会出现不一样,所以Training可能会起反作用,所以这种Training的作用范围也需要考虑到鼠标离种子点的距离,最后还要有一些平滑去噪的处理,具体都在[1]里有讲到,挺繁琐的(那会好像还没有SIFT),不详述了。

种子点位置的修正(Cursor Snap)

虽然这个算法可以自动找出种子点和鼠标之间最贴近边缘的路径,不过,人的手,常常抖,所以种子点未必能很好地设置到边缘上。所以可以在用户设置完种子点位置之后,自动在其坐标周围小范围内,比如7x7的区域内搜索cost最低的像素,作为真正的种子点位置,这个过程叫做Cursor snap。

更多Photoshop中磁力套索的一种简陋实现(基于Python) 相关文章请关注PHP中文网!


声明:
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn