天天看點

小張學linux核心之驅動篇:1. netlink套接字及domain socket的使用

netlink套接字

還記得學驅動開發時,使用mdev和udev工具在/dev下自動生成裝置節點吧。mdev/udev接收熱插拔(uevent)事件進而生成節點。愛鑽研的小夥伴已經知道,uevent事件是在核心的lib/kobject_uevent.c檔案裡產生的。mdev和udev的不同在於:mdev是以回調鈎子(hotplug helper)的方式被核心直接調用的,而udev則是作為一個守護程序,通過netlink socket接收uevent事件。下面我們來實踐一下netlink socket的使用。

核心側代碼

環境:linux kernel版本4.9.9

#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/types.h>
#include <linux/sched.h>
#include <net/sock.h>
#include <linux/netlink.h>
#include <linux/kthread.h>
#include <linux/sched.h>
#include <linux/err.h>
#include <linux/fs.h>
#include <linux/init.h>
#include <linux/cdev.h>

#define MEMDEV_MAJOR 255   /* default major number; overridable via the echo_major parameter */
#define MEMDEV_NR_DEVS 1   /* number of devices */
#define MEMDEV_SIZE 1024   /* size in bytes of each device's buffer */
#define NETLINK_TEST 17    /* netlink protocol number; user space must use the same value */
#define UP_TO_LOW 0        /* ioctl command: convert upper case to lower case */
#define LOW_TO_UP 1        /* ioctl command: convert lower case to upper case */
#define MAX_PID_COUNT 100  /* maximum number of client requests served */
#define MSG_LEN 125        /* size of each client's message buffer */
/*
 * SLEEP_MILLI_SEC - sleep for at least nMillisec milliseconds.
 *
 * The previous version called schedule_timeout() while the task was still
 * TASK_RUNNING; in that state the scheduler does not put the task to sleep,
 * so the loop just spun. It also ended in "while(0);", which breaks the
 * macro inside an if/else. schedule_timeout_uninterruptible() sets the task
 * state itself, and the loop resumes the remaining time after an early
 * wakeup. msecs_to_jiffies() rounds up, so short delays no longer become 0.
 */
#ifndef SLEEP_MILLI_SEC
#define SLEEP_MILLI_SEC(nMillisec)                                        \
    do {                                                                  \
        long timeout = (long)msecs_to_jiffies(nMillisec);                 \
        while (timeout > 0)                                               \
            timeout = schedule_timeout_uninterruptible(timeout);          \
    } while (0)
#endif
static int echo_major = MEMDEV_MAJOR;
module_param(echo_major, int, S_IRUGO);
struct echo_dev *echo_devp; /*裝置結構體指針*/
struct cdev cdev;
char *echo_dev = "echodev";
static struct sock *netlinkfd = NULL;
static struct task_struct *task_test[MAX_PID_COUNT];
static int pid_index = 0;
static int char_num = 0;
static int char_cnvt_flag = 0;
/* mem裝置描述結構體 */
struct echo_dev
{
    char *data; /* 配置設定到的記憶體的起始位址 */
    unsigned long size; /* 記憶體的大小 */
};
struct
{
    __u32 pid;
} user_process;
/* netlink */
struct echo_netlink
{
    __u32 pid; /* netlink pid */
    char buf[MSG_LEN]; /* data */
    int length; /* buf len */
};
struct echo_netlink client_netlink[MAX_PID_COUNT];
/* Forward declarations: echo_fops below refers to the handlers before their definitions. */
static int echo_open(struct inode *ino, struct file *file);
static ssize_t echo_read(struct file *file, char __user *ubuf, size_t count, loff_t *pos);
static long echo_ioctl(struct file *file, unsigned int cmd, unsigned long arg);

/* File operations of the echo character device. */
static const struct file_operations echo_fops = {
    .owner          = THIS_MODULE,
    .unlocked_ioctl = echo_ioctl,
    .read           = echo_read,
    .open           = echo_open,
};
/*
 * echo_open - open() handler for the echo device.
 * No per-open state is kept, so this only logs the event and succeeds.
 */
static int echo_open(struct inode *inode, struct file *filp)
{
    int status = 0;

    printk(KERN_DEBUG"[kernel space] open char device!!\n");
    return status;
}
/*
 * echo_read - report how many characters have been processed so far.
 *
 * The old handler returned char_num as a byte count without copying anything
 * into @buf. It also never advanced *ppos, so readers such as cat looped
 * forever, and it could return more than @size. The count is now delivered as
 * decimal text ("42\n"). simple_read_from_buffer() copies at most @size bytes
 * and advances *ppos. It returns 0 (EOF) once the text has been consumed.
 */
static ssize_t echo_read(struct file *filp, char __user *buf, size_t size, loff_t *ppos)
{
    char text[16];  /* enough for any int, '\n' and NUL */
    int text_len = scnprintf(text, sizeof(text), "%d\n", char_num);

    printk(KERN_DEBUG"[kernel space] read char device!!\n");
    return simple_read_from_buffer(buf, size, ppos, text, text_len);
}
static long echo_ioctl(struct file *filp, unsigned int cmd, unsigned long arg)
{
    int result = 0;
    switch(cmd)
        {
        case UP_TO_LOW:
            char_cnvt_flag = 0;
            break;
        case LOW_TO_UP:
            char_cnvt_flag = 1;
            break;
        default :
            result = -1;
            break;
        }
    printk(KERN_DEBUG"[kernel space] ioctl cmd: %d\n",char_cnvt_flag);
    return result;
}
int init_char_device(void)
{
    int i,result;
    dev_t devno = MKDEV(echo_major, 0);
    if (echo_major)
        /* 靜态申請裝置号*/
        result = register_chrdev_region(devno, 2, "echodev");
    else
        {
            /* 動态配置設定裝置号 */
            result = alloc_chrdev_region(&devno, 0, 2, "echodev");
            echo_major = MAJOR(devno);
        }
    if ( result<0 )
        return result;
    /* 初始化cdev結構 */
    cdev_init(&cdev, &echo_fops);
    cdev.owner = THIS_MODULE;
    cdev.ops = &echo_fops;
    /* 注冊字元裝置 */
    cdev_add(&cdev, MKDEV(echo_major, 0), MEMDEV_NR_DEVS);
    /* 為裝置描述結構配置設定記憶體 */
    echo_devp = kmalloc(MEMDEV_NR_DEVS * sizeof(struct echo_dev), GFP_KERNEL);
    /* 申請失敗 */
    if (!echo_devp)
        {
            result = -1;
            goto fail_malloc;
        }
    memset(echo_devp, 0, sizeof(struct echo_dev));
    /* 為裝置配置設定記憶體 */
    for(i= 0; i < MEMDEV_NR_DEVS; i++)
        {
            echo_devp[i].size = MEMDEV_SIZE;
            echo_devp[i].data = kmalloc(MEMDEV_SIZE, GFP_KERNEL);
            memset(echo_devp[i].data, 0, MEMDEV_SIZE);
        }
    printk(KERN_ERR"[kernel space] create char device successfuly!\n");
    return 0;
fail_malloc:
    unregister_chrdev_region(devno, 1);
    return result;
}
void delete_device(void)
{
    /* 登出裝置 */
    cdev_del(&cdev);
    /* 釋放裝置号 */
    unregister_chrdev_region(MKDEV(echo_major, 0), 2);
    printk(KERN_DEBUG"[kernel space] echo_cdev_del!!\n");
}
static int kernel_send_thread(void *index)
{
    int threadindex = *((int *)index);
    int size;
    struct sk_buff *skb;
    unsigned char *old_tail;
    struct nlmsghdr *nlh; //封包頭
    int retval;
    int i=0;
    size = NLMSG_SPACE(client_netlink[threadindex].length);
    /* 配置設定一個新的套接字緩存,使用GFP_ATOMIC标志程序不會被置為睡眠 */
    skb = alloc_skb(size, GFP_ATOMIC);
    /* 初始化一個netlink消息首部 */
    nlh = nlmsg_put(skb, 0, 0, 0, NLMSG_SPACE(client_netlink[threadindex].length)-sizeof(struct nlmsghdr), 0);
    old_tail = skb->tail;
//memcpy(NLMSG_DATA(nlh), client_netlink[i].buf, client_netlink[i].length); //填充資料區
    strcpy(NLMSG_DATA(nlh), client_netlink[threadindex].buf); //填充資料區
    nlh->nlmsg_len = skb->tail - old_tail; //設定消息長度
    /* 設定控制字段 */
    NETLINK_CB(skb).nsid = 0;
    NETLINK_CB(skb).dst_group = 0;
    printk(KERN_DEBUG "[kernel space] send to user: %s, send_pid: %d, send_len: %d\n", \
           (char *)NLMSG_DATA((struct nlmsghdr *)skb->data), client_netlink[threadindex].pid, \
           client_netlink[threadindex].length);
    /* 發送資料 */
    retval = netlink_unicast(netlinkfd, skb, client_netlink[threadindex].pid, MSG_DONTWAIT);
    if (retval<0)
        {
            printk(KERN_DEBUG "[kernel space] client closed: \n");
        }
    while(!(i = kthread_should_stop()))
        {
            printk(KERN_DEBUG "[kernel space] kthread_should_stop: %d\n", i);
            SLEEP_MILLI_SEC(1000*10);
        }
    return 0;
}
void char_convert(int id)
{
    int len = client_netlink[id].length;
    int i = 0;
    client_netlink[id].buf[len] = '\0';
    if( UP_TO_LOW == char_cnvt_flag )
        {
            printk(KERN_DEBUG "[kernel space] UP_TO_LOW\n");
            while(client_netlink[id].buf[i] != '\0')
                {
                    if(client_netlink[id].buf[i] >= 'A' && client_netlink[id].buf[i] <= 'Z')
                        {
                            client_netlink[id].buf[i] = client_netlink[id].buf[i] + 0x20;
                            mdelay(200);
                        }
                    i++;
                }
        }
    else if( LOW_TO_UP == char_cnvt_flag )
        {
            printk(KERN_DEBUG "[kernel space] LOW_TO_UP\n");
            while(client_netlink[id].buf[i] != '\0')
                {
                    if(client_netlink[id].buf[i] >= 'a' && client_netlink[id].buf[i] <= 'z')
                        {
                            client_netlink[id].buf[i] = client_netlink[id].buf[i] - 0x20;
                            mdelay(200);
                        }
                    i++;
                }
        }
    char_num += len;
}
void run_netlink_thread(int thread_index)
{
    int err;
    char process_name[64] = {0};
    void* data = kmalloc(sizeof(int), GFP_ATOMIC);
    *(int *)data = thread_index;
    snprintf(process_name, 63, "child_thread-%d", thread_index);
    task_test[thread_index] = kthread_create(kernel_send_thread, data, process_name);
    if(IS_ERR(task_test[thread_index]))
        {
            err = PTR_ERR(task_test[thread_index]);
            printk(KERN_DEBUG "[kernel space] creat child thread failure \n");
        }
    else
        {
            printk(KERN_DEBUG "[kernel space] creat child_thread-%d \n", thread_index);
            wake_up_process(task_test[thread_index]);
        }
}
/*
 * buf_deal - process one stored request: convert its text in place, then
 * start the thread that sends the result back to user space.
 */
void buf_deal(int id)
{
    char_convert(id);        /* step 1: case conversion */
    run_netlink_thread(id);  /* step 2: reply asynchronously */
}
void kernel_recv_thread(struct sk_buff *__skb)
{
    struct sk_buff *skb;
    struct nlmsghdr *nlh = NULL;
    char *recv_data = NULL;
    int pid_id = 0;
    printk(KERN_DEBUG "[kernel space] begin kernel_recv\n");
    skb = skb_get(__skb);
    if(skb->len >= NLMSG_SPACE(0))
        {
            nlh = nlmsg_hdr(skb);
            if(pid_index < MAX_PID_COUNT)
                {
                    client_netlink[pid_index].pid = nlh->nlmsg_pid;
                    recv_data = NLMSG_DATA(nlh);
                    strcpy(client_netlink[pid_index].buf,recv_data);
                    client_netlink[pid_index].length = strlen(recv_data);
                    printk(KERN_DEBUG "[kernel space] recv from user: %s, recv_pid: %d, recv_len: %d\n", \
                           (char *)NLMSG_DATA(nlh), client_netlink[pid_index].pid, strlen(recv_data));
                    pid_id = pid_index;
                    pid_index++;
                    buf_deal(pid_id);
                }
            else
                {
                    printk(KERN_DEBUG "[kernel space] out of pid\n");
                }
            kfree_skb(skb);
        }
}
int init_netlink(void)
{
	struct netlink_kernel_cfg nl_sock_cfg;
	nl_sock_cfg.input = kernel_recv_thread;
	
    netlinkfd = netlink_kernel_create(&init_net, NETLINK_TEST, &nl_sock_cfg);//(&init_net,NETLINK_TEST,0,kernel_recv_thread,NULL,THIS_MODULE);  linux 2.6的參數
    if(!netlinkfd )
        return -1;
    else
        {
            printk(KERN_ERR"[kernel space] create netlink successfuly!\n");
            return 0;
        }
}
/*
 * netlink_release - destroy the socket made by init_netlink().
 *
 * Kernel netlink sockets must be released with netlink_kernel_release(),
 * the counterpart of netlink_kernel_create(), not with sock_release(). The
 * pointer is cleared so a second call is harmless.
 */
void netlink_release(void)
{
    printk(KERN_DEBUG"[kernel space] echo_netlink_exit!\n");
    if (netlinkfd) {
        netlink_kernel_release(netlinkfd);
        netlinkfd = NULL;
    }
}
void stop_kthread(void)
{
    int i;
    printk(KERN_ERR"[kernel space] stop kthread!\n");
    for(i=0; i != pid_index; i++)
        {
            if(task_test[i] != NULL)
                {
                    kthread_stop(task_test[i]);
                    task_test[i] = NULL;
                }
        }
}
/* init_client - mark every client slot and its sender-thread handle as unused. */
void init_client(void)
{
    int slot;

    for (slot = MAX_PID_COUNT - 1; slot >= 0; slot--) {
        task_test[slot] = NULL;
        client_netlink[slot].pid = 0;
    }
}
/**
 * init_echo_module - module load entry point.
 *
 * Resets the client table, registers the character device and creates the
 * NETLINK_TEST kernel socket. If the socket cannot be created, the character
 * device is removed again so a failed load leaks nothing (it used to stay
 * registered).
 *
 * Return: 0 on success, a negative value on failure.
 */
static int __init init_echo_module(void)
{
    int result;

    init_client();
    result = init_char_device();
    if (result < 0) {
        printk(KERN_ERR"[kernel space] cannot create a char device!\n");
        return result;
    }
    result = init_netlink();
    if (result < 0) {
        printk(KERN_ERR"[kernel space] cannot create a netlinksocket!\n");
        delete_device();
        return result;
    }
    return 0;
}
/**
 * exit_echo_module - module unload entry point.
 *
 * Releases the netlink socket first, so no further requests should reach
 * kernel_recv_thread() and start new sender threads. Then it stops the
 * existing sender threads and finally removes the character device.
 */
static void __exit exit_echo_module(void)
{
    netlink_release();
    stop_kthread();
    delete_device();
}
/* Module entry/exit points and modinfo metadata. */
module_init(init_echo_module);
module_exit(exit_echo_module);

MODULE_AUTHOR("zhang");
MODULE_LICENSE("GPL");
MODULE_VERSION("V1.0");
           

使用者側代碼

#include <stdio.h>
#include <stdlib.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#include <string.h>
#include <sys/time.h>
#include <linux/netlink.h>
#include <signal.h>
#include <errno.h>

#define NETLINK_TEST	17	/* netlink protocol number used by the kernel module */
#define MSG_LEN			125	/* payload size of a request/reply */
#define BUF_LEN 		125	/* size of the stdin line buffer */
#define TIME  			210	/* seconds to wait for the kernel's reply */

/* A netlink message as received from the kernel: header, then text. */
struct u_packet_info
{
	struct nlmsghdr hdr;
	char msg[MSG_LEN];
};

int skfd;					/* netlink socket descriptor */
struct sockaddr_nl local;	/* our own address (filled in by init_netlink) */
struct sockaddr_nl dest;	/* the kernel's address (nl_pid 0) */
struct nlmsghdr *message;	/* outgoing request built by init_netlink */

/*
 * SIGINT handler (main() installs it for SIGINT despite its name): close the
 * netlink socket and terminate.
 *
 * Only async-signal-safe calls are used. printf() and exit() are not safe
 * inside a signal handler and were replaced by write() and _exit().
 * local.nl_pid is this process's own pid once init_netlink() has run, so the
 * kill() raises SIGUSR1 on ourselves, as before. It is skipped while nl_pid
 * is still 0, because kill(0, ...) would signal the whole process group.
 */
static void sig_pipe(int sign)
{
	static const char msg[] = "Catch a SIGPIPE signal!\n";
	ssize_t unused;

	(void)sign;
	unused = write(STDOUT_FILENO, msg, sizeof(msg) - 1);
	(void)unused;
	close(skfd);
	if (local.nl_pid != 0)
		kill(local.nl_pid, SIGUSR1);
	_exit(-1);
}

/*
 * init_netlink - open and bind the NETLINK_TEST socket, then build the
 * request in the global "message" from one non-empty line read from stdin.
 *
 * Returns 0 on success, or -1 on allocation/socket/bind failure or when
 * stdin reaches EOF.
 */
int init_netlink(void)
{
	char send_data[BUF_LEN];
	size_t data_len;

	/*
	 * main() sends nlmsg_len = NLMSG_SPACE(MSG_LEN) bytes from this buffer and
	 * the text is written at NLMSG_DATA(), so it must be that large. The old
	 * sizeof(struct nlmsghdr) allocation overflowed the heap. calloc() also
	 * zero-fills, which keeps the payload NUL-terminated.
	 */
	message = calloc(1, NLMSG_SPACE(MSG_LEN));
	if (message == NULL)
	{
		printf("can not allocate the netlink message!\n");
		return -1;
	}

	skfd = socket(PF_NETLINK, SOCK_RAW, NETLINK_TEST);
	if (skfd < 0)
	{
		printf("can not create a netlink socket! errno = %d\n", errno);
		return -1;
	}

	/* Our address: unicast port id = our pid, no multicast groups. */
	memset(&local, 0, sizeof(local));
	local.nl_family = AF_NETLINK;
	local.nl_pid    = getpid();
	local.nl_groups = 0;

	if (bind(skfd, (struct sockaddr*)&local, sizeof(local)) != 0)
	{
		printf("bind() error!\n");
		return -1;
	}

	/* Destination: port id 0 is the kernel. */
	memset(&dest, 0, sizeof(dest));
	dest.nl_family = AF_NETLINK;
	dest.nl_pid    = 0;
	dest.nl_groups = 0;

	message->nlmsg_len   = NLMSG_SPACE(MSG_LEN);
	message->nlmsg_flags = 0;
	message->nlmsg_type  = 0;
	message->nlmsg_seq   = 0;
	message->nlmsg_pid   = local.nl_pid;

	/* Prompt until a non-empty line is read. EOF or a read error aborts. */
	for (;;)
	{
		printf("input the data:");
		fflush(stdout);
		if (fgets(send_data, sizeof(send_data), stdin) == NULL)
		{
			printf("\nno input data!\n");
			return -1;
		}
		/* Strip the newline if present (the old "strlen - 1" could drop a real char). */
		send_data[strcspn(send_data, "\n")] = '\0';
		if (send_data[0] != '\0')
			break;
	}

	data_len = strlen(send_data);
	if (data_len > MSG_LEN - 1)
		data_len = MSG_LEN - 1;
	memcpy(NLMSG_DATA(message), send_data, data_len);
	printf("send to kernel: %s, send_id: %u send_len: %zu\n",
			(char *)NLMSG_DATA(message), local.nl_pid, data_len);
	return 0;
}

int main(int argc, char **argv)
{
	int ret;
	int len;
	fd_set fd_sets;
	socklen_t destlen = sizeof(struct sockaddr_nl);
	struct timeval select_time;
	struct u_packet_info info;
	signal(SIGINT, sig_pipe);

	ret = init_netlink();
	if (ret < 0)
	{
		close(skfd);
		perror("netlink failure!");
		exit(-1);
	}

	FD_ZERO(&fd_sets);
	FD_SET(skfd, &fd_sets);

	len = sendto(skfd, message, message->nlmsg_len, 0, (struct sockaddr*)&dest, sizeof(dest));
	if (!len)
	{
		perror("send pid:");
		exit(-1);
	}
	select_time.tv_sec = TIME;
	select_time.tv_usec = 0;
	ret = select(skfd+1, &fd_sets, NULL, NULL, &select_time);
	if (ret > 0)
	{
		len = recvfrom(skfd, &info, sizeof(struct u_packet_info), 0, (struct sockaddr*)&dest, &destlen);
		printf("recv from kernel:%s, recv_len: %d\n", (char *)info.msg, strlen(info.msg));
	}
	else if (ret < 0)
	{
		perror("\n error!\n");
		exit(-1);
	}
	else
	{
		printf("\n kernel server disconnect!\n");
		kill(local.nl_pid, SIGUSR1);
	}
	close(skfd);
	return 0;
}
           

domain socket

netlink socket用於核心與使用者空間之間的通訊,而domain socket則用於使用者空間兩個程序之間的通訊,例如在lgui中視窗之間的通信等場景。它其實類似於管道,以檔案系統中的一個socket檔案作為通訊位址(資料本身並不寫入該檔案)。廢話不多說,上代碼。

用戶端:

#include <stdio.h>
#include <stdlib.h>
#include <stddef.h>
#include <string.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <unistd.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <errno.h>


/*
 * cli_con - create a UNIX stream socket, bind it to "cli.socket" in the
 * current directory and connect it to the server socket at @name.
 *
 * Returns the connected descriptor, or:
 *   -1 socket() failed,
 *   -2 bind() failed,
 *   -4 connect() failed, or @name does not fit in sun_path (errno = ENAMETOOLONG).
 * errno is preserved across the cleanup on failure.
 */
int cli_con(const char *name)
{
	int cli_fd;
	int ret = 0;
	int err;
	struct sockaddr_un local_un;
	struct sockaddr_un serv_un;
	socklen_t localadr_len;
	socklen_t servadr_len;

	/* sun_path is a fixed-size array; the old strcpy() could overflow it. */
	if (strlen(name) >= sizeof(serv_un.sun_path))
	{
		errno = ENAMETOOLONG;
		return -4;
	}

	cli_fd = socket(AF_UNIX, SOCK_STREAM, 0);
	if (cli_fd < 0)
	{
		return -1;
	}

	memset(&local_un, 0, sizeof(local_un));
	local_un.sun_family = AF_UNIX;
	strcpy(local_un.sun_path, "cli.socket");
	localadr_len = offsetof(struct sockaddr_un, sun_path) + strlen(local_un.sun_path);
	/* Remove a stale socket file first, otherwise bind() fails with EADDRINUSE. */
	unlink(local_un.sun_path);
	if (bind(cli_fd, (struct sockaddr*)&local_un, localadr_len) < 0)
	{
		ret = -2;
		goto _ERR;
	}

	memset(&serv_un, 0, sizeof(serv_un));
	serv_un.sun_family = AF_UNIX;
	strcpy(serv_un.sun_path, name);
	servadr_len = offsetof(struct sockaddr_un, sun_path) + strlen(serv_un.sun_path);

	if (connect(cli_fd, (struct sockaddr*)&serv_un, servadr_len) < 0)
	{
		ret = -4;
		goto _ERR_BOUND;
	}
	return cli_fd;

_ERR_BOUND:
	/* Don't leave our bound socket file behind after a failed connect. */
	err = errno;
	unlink(local_un.sun_path);
	errno = err;
_ERR:
	err = errno;
	close(cli_fd);
	errno = err;
	return ret;
}


/*
 * main - connect to "foo.socket", then for each stdin line send the text and
 * print the echoed reply. Exits 0 on stdin EOF, -1 on errors.
 */
int main()
{
	int cli_fd;
	char send_buf[1024];
	int status = 0;

	cli_fd = cli_con("foo.socket");
	if (cli_fd < 0)
	{
		switch (cli_fd)
		{
		case -4:
			perror("connect");
			break;
		case -3:
			perror("listen");
			break;
		case -2:
			perror("bind");
			break;
		case -1:
			perror("socket");
			break;
		default:
			break;
		}
		exit(-1);
	}

	while (fgets(send_buf, sizeof(send_buf), stdin) != NULL)
	{
		/* Send only the text read, not the whole 1024-byte buffer with its NUL padding. */
		size_t to_send = strlen(send_buf);
		size_t done = 0;
		ssize_t n;

		while (done < to_send)
		{
			n = write(cli_fd, send_buf + done, to_send - done);
			if (n < 0)
			{
				if (errno == EINTR)
					continue;
				perror("write");
				status = -1;
				goto out;
			}
			done += (size_t)n;
		}

		/* The server echoes every byte it receives; collect that many bytes. */
		done = 0;
		while (done < to_send)
		{
			n = read(cli_fd, send_buf, sizeof(send_buf));
			if (n < 0)
			{
				if (errno == EINTR)
					continue;
				perror("read");
				status = -1;
				goto out;
			}
			if (n == 0)
			{
				fprintf(stderr, "the server has closed the connection.\n");
				status = -1;
				goto out;
			}
			write(STDOUT_FILENO, send_buf, (size_t)n);
			done += (size_t)n;
		}
	}
out:
	close(cli_fd);
	return status;
}
           

伺服器端:

#include <stdio.h>
#include <stdlib.h>
#include <stddef.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <unistd.h>
#include <errno.h>



/*
 * serv_listen - create a listening UNIX stream socket bound to path @name.
 * Any existing file at @name is removed first.
 *
 * Returns the listening descriptor, or:
 *   -1 socket() failed,
 *   -2 bind() failed, or @name does not fit in sun_path (errno = ENAMETOOLONG),
 *   -3 listen() failed.
 * errno is preserved across the cleanup on failure.
 */
int serv_listen(const char *name)
{
    struct sockaddr_un un;
    int fd;
    int len;
    int ret = 0;
    int err;

    /* sun_path is a fixed-size array; the old strcpy() could overflow it. */
    if (strlen(name) >= sizeof(un.sun_path))
    {
        errno = ENAMETOOLONG;
        return -2;
    }

    fd = socket(AF_UNIX, SOCK_STREAM, 0);
    if (fd < 0)
    {
        return -1;
    }
    /* Remove a stale socket file first, otherwise bind() fails. */
    unlink(name);
    memset(&un, 0, sizeof(un));
    un.sun_family = AF_UNIX;
    strcpy(un.sun_path, name);
    len = offsetof(struct sockaddr_un, sun_path) + strlen(name);

    if (bind(fd, (struct sockaddr*)&un, len) < 0)
    {
        ret = -2;
        goto _ERR;
    }
    if (listen(fd, 10) < 0)  /* backlog of 10 pending connections */
    {
        ret = -3;
        goto _ERR;
    }
    return fd;

_ERR:
    err = errno;
    close(fd);
    errno = err;
    return ret;
}

int serv_accept(int listenfd, uid_t *uidptr)
{
    int ret = 0;
    int err;
    int cli_fd;
    struct sockaddr_un cli_un;
    struct stat statbuf;
    int len;
    len = sizeof(cli_un);

    cli_fd = accept(listenfd, (struct sockaddr*)&cli_un, &len);
    if (cli_fd < 0)
    {
        return -1;
    }

    len = len - offsetof(struct sockaddr_un, sun_path) ;
    cli_un.sun_path[len] = 0;   /*末尾補零*/
    ret = stat(cli_un.sun_path, &statbuf);
    if (ret < 0)
    {
        ret = -2;
        goto _ERR;
    }
    if (S_ISSOCK(statbuf.st_mode) == 0)
    {
        ret = -3;
        goto _ERR;
    }
    if (uidptr != NULL)
    {
        *uidptr = statbuf.st_uid;
    }

    unlink(cli_un.sun_path);
    return cli_fd;
_ERR:
    err = errno;
    close(cli_fd);
    errno = err;
    return(ret);
}

#include <ctype.h>  /* toupper() */

/*
 * main - listen on "foo.socket", accept one client, and echo everything it
 * sends back in upper case until the client closes the connection or an
 * I/O error occurs.
 */
int main()
{
    int listen_fd;
    int con_fd;
    uid_t cuid;
    char recv_buf[1024];
    ssize_t recved_len;
    ssize_t i;

    listen_fd = serv_listen("foo.socket");
    if (listen_fd < 0)
    {
        switch (listen_fd)
        {
        case -3:
            perror("listen");
            break;
        case -2:
            perror("bind");
            break;
        case -1:
            perror("socket");
            break;
        default:
            break;
        }
        exit(-1);
    }

    con_fd = serv_accept(listen_fd, &cuid);
    if (con_fd < 0)
    {
        switch (con_fd)
        {
        case -3:
            perror("not a socket");
            break;
        case -2:
            perror("a bad filename");
            break;
        case -1:
            perror("accept");
            break;
        default:
            break;
        }
        exit(-1);
    }
    printf("accept successed!\n");

    while (1)
    {
        recved_len = read(con_fd, recv_buf, sizeof(recv_buf));
        if (recved_len < 0)
        {
            if (EINTR == errno)
                continue;
            /* Any other error: stop. The old code fell through to write(..., -1) and spun. */
            perror("read");
            break;
        }
        if (recved_len == 0)
        {
            printf("the other side has been closed.\n");
            break;
        }

        /* toupper() needs an unsigned char value; a negative char is undefined behaviour. */
        for (i = 0; i < recved_len; i++)
        {
            recv_buf[i] = (char)toupper((unsigned char)recv_buf[i]);
        }

        if (write(con_fd, recv_buf, (size_t)recved_len) < 0)
        {
            perror("write");
            break;
        }
    }
    close(con_fd);
    close(listen_fd);
    return 0;
}

           

繼續閱讀