AI&BigData/Deep Learning
Lab05. Logistic classification
eunguru
2018. 4. 18. 11:11
Logistic (regression) classification
1. Logistic regression
Logistic regression은 이진분류(binary classification) 문제를 해결하기 위한 모델
1) Binary classification
두 개중 한개를 정해진 카테고리에서 고르는 것
- 스팸메일 탐지: spam / ham
- 페이스북 피드: show / hide
- 신용카드 부정거래 탐지: legitimate / fraud
0,1 encoding: 기계적인 학습을 위해 0 또는 1로 변환 필요
2) Binary logistic classification의 hypothesis, cost function
Binary classification은 0또는 1의 값을 가져야 하기때문에 linear regression에서 사용한 cost함수는 사용 불가
압축을 위해 sigmoid 함수 사용
- 학습은 동일하게 gradient decent optimizer를 사용
2. Example Code
1) Logistic Classification
소스코드
import tensorflow as tf x_data = [[1, 2], [2, 3], [3, 1], [4, 3], [5, 3], [6, 2]] y_data = [[0], [0], [0], [1], [1], [1]] X = tf.placeholder(tf.float32, shape=[None, 2]) Y = tf.placeholder(tf.float32, shape=[None, 1]) W = tf.Variable(tf.random_normal([2, 1]), name = 'weight') b = tf.Variable(tf.random_normal([1]), name = 'bias') # Hypothesis using sigmoid: tf.div(1., 1. + tf.exp(tf.matmul(X, W) + b)) hypothesis = tf.sigmoid(tf.matmul(X, W) + b) cost = -tf.reduce_mean(Y*tf.log(hypothesis) + (1 - Y)*tf.log(1 - hypothesis)) train = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(cost) predicted = tf.cast(hypothesis > 0.5, dtype = tf.float32) accuracy = tf.reduce_mean(tf.cast(tf.equal(predicted, Y), dtype=tf.float32)) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) for step in range(10001): cost_val, _ = sess.run([cost, train], feed_dict={X: x_data, Y: y_data}) if step % 200 == 0: print("step: ", step, "cost: ", cost_val) h, c, a = sess.run([hypothesis, predicted, accuracy], feed_dict={X: x_data, Y: y_data}) print("\nHypothesis: ", h, "\nCorrect (Y): ", c, "\nAccuracy: ", a)
결과
(...) step: 9000 cost: 0.15864347 step: 9200 cost: 0.15615107 step: 9400 cost: 0.15373564 step: 9600 cost: 0.15139394 step: 9800 cost: 0.14912277 step: 10000 cost: 0.14691907 Hypothesis: [[0.02966413] [0.1573633 ] [0.29967391] [0.78376436] [0.94106627] [0.9806702 ]] Correct (Y): [[0.] [0.] [0.] [1.] [1.] [1.]] Accuracy: 1.0