Ich habe nachgeschaut, kann aber anscheinend keine Beispiele dafür finden, wie man einen One-Hot-Wert in TensorFlow dekodiert oder zurück in eine einzelne Ganzzahl konvertiert.
Ich habe tf.one_hot
verwendet und konnte mein Modell trainieren, bin aber etwas verwirrt darüber, wie ich das Etikett nach meiner Klassifizierung verstehen soll. Meine Daten werden über eine von mir erstellte TFRecords
-Datei eingespeist. Ich habe darüber nachgedacht, ein Textetikett in der Datei zu speichern, konnte es aber nicht zum Laufen bringen. Es sah so aus, als ob TFRecords
keine Textzeichenfolge speichern könnte oder ich mich geirrt habe.
Sie können den Index des größten Elements in der Matrix mit tf.argmax
ermitteln. Da Ihr einziger heißer Vektor eindimensional ist und nur einen 1
und einen anderen 0
hat, funktioniert dies unter der Annahme, dass Sie sich mit einem einzelnen Vektor befassen.
index = tf.argmax(one_hot_vector, axis=0)
Für die Standardmatrix von batch_size * num_classes
verwenden Sie axis=1
, um ein Ergebnis der Größe batch_size * 1
zu erhalten.
Da eine One-Hot-Codierung in der Regel nur eine Matrix mit batch_size
Zeilen und num_classes
Spalten ist und jede Zeile Null ist, wobei eine einzige Nicht-Null der ausgewählten Klasse entspricht, können Sie tf.argmax()
verwenden, um einen Ganzzahlvektor wiederherzustellen Etiketten:
BATCH_SIZE = 3
NUM_CLASSES = 4
one_hot_encoded = tf.constant([[0, 1, 0, 0],
[1, 0, 0, 0],
[0, 0, 0, 1]])
# Compute the argmax across the columns.
decoded = tf.argmax(one_hot_encoded, axis=1)
# ...
print sess.run(decoded) # ==> array([1, 0, 3])
data = np.array([1, 5, 3, 8])
print(data)
def encode(data):
print('Shape of data (BEFORE encode): %s' % str(data.shape))
encoded = to_categorical(data)
print('Shape of data (AFTER encode): %s\n' % str(encoded.shape))
return encoded
encoded_data = encode(data)
print(encoded_data)
def decode(datum):
return np.argmax(datum)
decoded_Y = []
print("****************************************")
for i in range(encoded_data.shape[0]):
datum = encoded_data[i]
print('index: %d' % i)
print('encoded datum: %s' % datum)
decoded_datum = decode(encoded_data[i])
print('decoded datum: %s' % decoded_datum)
decoded_Y.append(decoded_datum)
print("****************************************")
print(decoded_Y)